Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed unhelpful error message in from_thread_run functions. #1513

Merged
merged 16 commits into from
May 21, 2020
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 5 additions & 82 deletions trio/_core/_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from sniffio import current_async_library_cvar

import attr
from async_generator import isasyncgen
from sortedcontainers import SortedDict
from outcome import Error, Value, capture

Expand All @@ -36,7 +35,7 @@
)
from .. import _core
from .._deprecate import deprecated
from .._util import Final, NoPublicConstructor
from .._util import Final, NoPublicConstructor, coroutine_or_error

_NO_SEND = object()

Expand Down Expand Up @@ -1247,86 +1246,7 @@ def spawn_impl(self, async_fn, args, nursery, name, *, system_task=False):
# Call the function and get the coroutine object, while giving helpful
guilledk marked this conversation as resolved.
Show resolved Hide resolved
# errors for common mistakes.
######

def _return_value_looks_like_wrong_library(value):
# Returned by legacy @asyncio.coroutine functions, which includes
# a surprising proportion of asyncio builtins.
if isinstance(value, collections.abc.Generator):
return True
# The protocol for detecting an asyncio Future-like object
if getattr(value, "_asyncio_future_blocking", None) is not None:
return True
# This janky check catches tornado Futures and twisted Deferreds.
# By the time we're calling this function, we already know
# something has gone wrong, so a heuristic is pretty safe.
if value.__class__.__name__ in ("Future", "Deferred"):
return True
return False

try:
coro = async_fn(*args)
except TypeError:
# Give good error for: nursery.start_soon(trio.sleep(1))
if isinstance(async_fn, collections.abc.Coroutine):
raise TypeError(
"Trio was expecting an async function, but instead it got "
"a coroutine object {async_fn!r}\n"
"\n"
"Probably you did something like:\n"
"\n"
" trio.run({async_fn.__name__}(...)) # incorrect!\n"
" nursery.start_soon({async_fn.__name__}(...)) # incorrect!\n"
"\n"
"Instead, you want (notice the parentheses!):\n"
"\n"
" trio.run({async_fn.__name__}, ...) # correct!\n"
" nursery.start_soon({async_fn.__name__}, ...) # correct!"
.format(async_fn=async_fn)
) from None

# Give good error for: nursery.start_soon(future)
if _return_value_looks_like_wrong_library(async_fn):
raise TypeError(
"Trio was expecting an async function, but instead it got "
"{!r} – are you trying to use a library written for "
"asyncio/twisted/tornado or similar? That won't work "
"without some sort of compatibility shim."
.format(async_fn)
) from None

raise

# We can't check iscoroutinefunction(async_fn), because that will fail
# for things like functools.partial objects wrapping an async
# function. So we have to just call it and then check whether the
# return value is a coroutine object.
if not isinstance(coro, collections.abc.Coroutine):
# Give good error for: nursery.start_soon(func_returning_future)
if _return_value_looks_like_wrong_library(coro):
raise TypeError(
"start_soon got unexpected {!r} – are you trying to use a "
"library written for asyncio/twisted/tornado or similar? "
"That won't work without some sort of compatibility shim."
.format(coro)
)

if isasyncgen(coro):
raise TypeError(
"start_soon expected an async function but got an async "
"generator {!r}".format(coro)
)

# Give good error for: nursery.start_soon(some_sync_fn)
raise TypeError(
"Trio expected an async function, but {!r} appears to be "
"synchronous".format(
getattr(async_fn, "__qualname__", async_fn)
)
)

######
# Set up the Task object
guilledk marked this conversation as resolved.
Show resolved Hide resolved
######
coro = coroutine_or_error(async_fn, *args)

if name is None:
name = async_fn
Expand All @@ -1353,6 +1273,9 @@ async def python_wrapper(orig_coro):
LOCALS_KEY_KI_PROTECTION_ENABLED, system_task
)

######
# Set up the Task object
######
task = Task._create(
coro=coro,
parent_nursery=nursery,
Expand Down
87 changes: 17 additions & 70 deletions trio/_core/tests/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@
import sniffio
import pytest

from .tutil import slow, check_sequence_matches, gc_collect_harder
from .tutil import (
slow, check_sequence_matches, gc_collect_harder,
ignore_coroutine_never_awaited_warnings
)

from ... import _core
from ..._threads import to_thread_run_sync
from ..._timeouts import sleep, fail_after
Expand All @@ -33,24 +37,6 @@ async def sleep_forever():
return await _core.wait_task_rescheduled(lambda _: _core.Abort.SUCCEEDED)


# Some of our tests need to leak coroutines, and thus trigger the
# "RuntimeWarning: coroutine '...' was never awaited" message. This context
# manager should be used anywhere this happens to hide those messages, because
# when expected they're clutter.
@contextmanager
def ignore_coroutine_never_awaited_warnings():
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore", message="coroutine '.*' was never awaited"
)
try:
yield
finally:
# Make sure to trigger any coroutine __del__ methods now, before
# we leave the context manager.
gc_collect_harder()


def test_basic():
async def trivial(x):
return x
Expand Down Expand Up @@ -1696,8 +1682,6 @@ async def test_current_effective_deadline(mock_clock):
assert _core.current_effective_deadline() == inf


# @coroutine is deprecated since python 3.8, which is fine with us.
@pytest.mark.filterwarnings("ignore:.*@coroutine.*:DeprecationWarning")
def test_nice_error_on_bad_calls_to_run_or_spawn():
def bad_call_run(*args):
guilledk marked this conversation as resolved.
Show resolved Hide resolved
_core.run(*args)
Expand All @@ -1709,59 +1693,22 @@ async def main():

_core.run(main)

class Deferred:
"Just kidding"

with ignore_coroutine_never_awaited_warnings():
for bad_call in bad_call_run, bad_call_spawn:

async def f(): # pragma: no cover
pass

with pytest.raises(TypeError) as excinfo:
bad_call(f())
assert "expecting an async function" in str(excinfo.value)

import asyncio

@asyncio.coroutine
def generator_based_coro(): # pragma: no cover
yield from asyncio.sleep(1)

with pytest.raises(TypeError) as excinfo:
bad_call(generator_based_coro())
assert "asyncio" in str(excinfo.value)
for bad_call in bad_call_run, bad_call_spawn:

with pytest.raises(TypeError) as excinfo:
bad_call(asyncio.Future())
assert "asyncio" in str(excinfo.value)

with pytest.raises(TypeError) as excinfo:
bad_call(lambda: asyncio.Future())
assert "asyncio" in str(excinfo.value)

with pytest.raises(TypeError) as excinfo:
bad_call(Deferred())
assert "twisted" in str(excinfo.value)

with pytest.raises(TypeError) as excinfo:
bad_call(lambda: Deferred())
assert "twisted" in str(excinfo.value)

with pytest.raises(TypeError) as excinfo:
bad_call(len, [1, 2, 3])
assert "appears to be synchronous" in str(excinfo.value)
async def f(): # pragma: no cover
pass

async def async_gen(arg): # pragma: no cover
yield
with pytest.raises(TypeError, match="expecting an async function"):
bad_call(f())

with pytest.raises(TypeError) as excinfo:
bad_call(async_gen, 0)
msg = "expected an async function but got an async generator"
assert msg in str(excinfo.value)
async def async_gen(arg): # pragma: no cover
yield arg

# Make sure no references are kept around to keep anything alive
del excinfo
with pytest.raises(
TypeError,
match="expected an async function but got an async generator"
):
bad_call(async_gen, 0)


def test_calling_asyncio_function_gives_nice_error():
Expand Down
20 changes: 20 additions & 0 deletions trio/_core/tests/tutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import os

import pytest
import warnings
from contextlib import contextmanager

import gc

Expand Down Expand Up @@ -52,6 +54,24 @@ def gc_collect_harder():
gc.collect()


# Some of our tests need to leak coroutines, and thus trigger the
# "RuntimeWarning: coroutine '...' was never awaited" message. This context
# manager should be used anywhere this happens to hide those messages, because
# when expected they're clutter.
@contextmanager
def ignore_coroutine_never_awaited_warnings():
guilledk marked this conversation as resolved.
Show resolved Hide resolved
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore", message="coroutine '.*' was never awaited"
)
try:
yield
finally:
# Make sure to trigger any coroutine __del__ methods now, before
# we leave the context manager.
gc_collect_harder()


# template is like:
# [1, {2.1, 2.2}, 3] -> matches [1, 2.1, 2.2, 3] or [1, 2.2, 2.1, 3]
def check_sequence_matches(seq, template):
Expand Down
22 changes: 17 additions & 5 deletions trio/_threads.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
from itertools import count

import attr
import inspect
import outcome

import trio

from ._sync import CapacityLimiter
from ._core import enable_ki_protection, disable_ki_protection, RunVar, TrioToken
from ._util import coroutine_or_error

# Global due to Threading API, thread local storage for trio token
TOKEN_LOCAL = threading.local()
Expand Down Expand Up @@ -365,6 +367,7 @@ def from_thread_run(afn, *args, trio_token=None):
which would otherwise cause a deadlock.
AttributeError: if no ``trio_token`` was provided, and we can't infer
one from context.
TypeError: if ``afn`` is not an asynchronous function.

**Locating a Trio Token**: There are two ways to specify which
`trio.run` loop to reenter:
Expand All @@ -380,7 +383,8 @@ def from_thread_run(afn, *args, trio_token=None):
def callback(q, afn, args):
@disable_ki_protection
async def unprotected_afn():
return await afn(*args)
coro = coroutine_or_error(afn, *args)
return await coro

async def await_in_trio_thread_task():
q.put_nowait(await outcome.acapture(unprotected_afn))
Expand All @@ -403,13 +407,11 @@ def from_thread_run_sync(fn, *args, trio_token=None):
Raises:
RunFinishedError: if the corresponding call to `trio.run` has
already completed.
Cancelled: if the corresponding call to `trio.run` completes
while ``afn(*args)`` is running, then ``afn`` is likely to raise
:exc:`trio.Cancelled`, and this will propagate out into
RuntimeError: if you try calling this from inside the Trio thread,
which would otherwise cause a deadlock.
AttributeError: if no ``trio_token`` was provided, and we can't infer
one from context.
TypeError: if ``fn`` is an async function.

**Locating a Trio Token**: There are two ways to specify which
`trio.run` loop to reenter:
Expand All @@ -425,7 +427,17 @@ def from_thread_run_sync(fn, *args, trio_token=None):
def callback(q, fn, args):
@disable_ki_protection
def unprotected_fn():
return fn(*args)
ret = fn(*args)

if inspect.iscoroutine(ret):
# Manually close coroutine to avoid RuntimeWarnings
ret.close()
raise TypeError(
"Trio expected a sync function, but {!r} appears to be "
"asynchronous".format(getattr(fn, "__qualname__", fn))
)

return ret

res = outcome.capture(unprotected_fn)
q.put_nowait(res)
Expand Down
Loading