diff --git a/newsfragments/2191.bugfix.rst b/newsfragments/2191.bugfix.rst new file mode 100644 index 0000000000..59f3a198c4 --- /dev/null +++ b/newsfragments/2191.bugfix.rst @@ -0,0 +1 @@ +`trio.from_thread.run` and `trio.from_thread.run_sync` no longer raise `RuntimeError` when used to enter a thread running Trio from a different thread running Trio. diff --git a/trio/_threads.py b/trio/_threads.py index 15435d44aa..056c242205 100644 --- a/trio/_threads.py +++ b/trio/_threads.py @@ -245,13 +245,18 @@ def _run_fn_as_system_task(cb, fn, *args, context, trio_token=None): "this thread wasn't created by Trio, pass kwarg trio_token=..." ) - # Avoid deadlock by making sure we're not called from Trio thread + # Avoid deadlock by making sure we're not called from the same Trio thread we're + # trying to enter try: - trio.lowlevel.current_task() + current_trio_token = trio.lowlevel.current_trio_token() except RuntimeError: pass else: - raise RuntimeError("this is a blocking function; call it from a thread") + if trio_token == current_trio_token: + raise RuntimeError( + "this is a blocking function; using it to re-enter the current Trio " + "thread would cause a deadlock" + ) q = stdlib_queue.SimpleQueue() trio_token.run_sync_soon(context.run, cb, q, fn, args) diff --git a/trio/tests/test_threads.py b/trio/tests/test_threads.py index ce852d4612..4e15ee9805 100644 --- a/trio/tests/test_threads.py +++ b/trio/tests/test_threads.py @@ -818,7 +818,7 @@ def test_run_fn_as_system_task_catched_badly_typed_token(): from_thread_run_sync(_core.current_time, trio_token="Not TrioTokentype") -async def test_from_thread_inside_trio_thread(): +async def test_from_thread_inside_same_trio_thread(): def not_called(): # pragma: no cover assert False @@ -827,6 +827,15 @@ def not_called(): # pragma: no cover from_thread_run_sync(not_called, trio_token=trio_token) +async def test_from_thread_inside_different_trio_thread(): + target_token = current_trio_token() + + async def thread_fn(): + return from_thread_run_sync(current_trio_token, trio_token=target_token) + + assert target_token == await to_thread_run_sync(_core.run, thread_fn) + + @pytest.mark.skipif(buggy_pypy_asyncgens, reason="pypy 7.2.0 is buggy") def test_from_thread_run_during_shutdown(): save = []