Skip to content

Commit

Permalink
Merge pull request #2392 from richardsheridan/from_thread_check_cance…
Browse files Browse the repository at this point in the history
…lled

Expand cancellation usability from native trio threads
  • Loading branch information
richardsheridan committed Oct 18, 2023
2 parents b161fec + ab092b0 commit b324b3a
Show file tree
Hide file tree
Showing 8 changed files with 441 additions and 130 deletions.
19 changes: 19 additions & 0 deletions docs/source/reference-core.rst
Expand Up @@ -1823,6 +1823,25 @@ to spawn a child thread, and then use a :ref:`memory channel

.. literalinclude:: reference-core/from-thread-example.py

.. note::

The ``from_thread.run*`` functions reuse the host task that called
:func:`trio.to_thread.run_sync` to run your provided function, as long as you're
using the default ``cancellable=False`` so Trio can be sure that the task will remain
around to perform the work. If you pass ``cancellable=True`` at the outset, or if
you provide a :class:`~trio.lowlevel.TrioToken` when calling back in to Trio, your
functions will be executed in a new system task. Therefore, the
:func:`~trio.lowlevel.current_task`, :func:`current_effective_deadline`, or other
task-tree specific values may differ depending on keyword argument values.

You can also use :func:`trio.from_thread.check_cancelled` to check for cancellation from
a thread that was spawned by :func:`trio.to_thread.run_sync`. If the call to
:func:`~trio.to_thread.run_sync` was cancelled (even if ``cancellable=False``!), then
:func:`~trio.from_thread.check_cancelled` will raise :func:`trio.Cancelled`.
It's like ``trio.from_thread.run(trio.sleep, 0)``, but much faster.

.. autofunction:: trio.from_thread.check_cancelled

Threads and task-local storage
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Expand Down
5 changes: 5 additions & 0 deletions newsfragments/2392.feature.rst
@@ -0,0 +1,5 @@
If called from a thread spawned by `trio.to_thread.run_sync`, `trio.from_thread.run` and
`trio.from_thread.run_sync` now reuse the task and cancellation status of the host task;
this means that context variables and cancel scopes naturally propagate 'through'
threads spawned by Trio. You can also use `trio.from_thread.check_cancelled`
to efficiently check for cancellation without reentering the Trio thread.
203 changes: 188 additions & 15 deletions trio/_tests/test_threads.py
Expand Up @@ -13,13 +13,12 @@
import pytest
import sniffio

from trio._core import TrioToken, current_trio_token

from .. import CapacityLimiter, Event, _core, sleep
from .. import CapacityLimiter, Event, _core, fail_after, sleep, sleep_forever
from .._core._tests.test_ki import ki_self
from .._core._tests.tutil import buggy_pypy_asyncgens
from .._threads import (
current_default_thread_limiter,
from_thread_check_cancelled,
from_thread_run,
from_thread_run_sync,
to_thread_run_sync,
Expand Down Expand Up @@ -645,7 +644,7 @@ async def async_fn(): # pragma: no cover
def thread_fn():
from_thread_run_sync(async_fn)

with pytest.raises(TypeError, match="expected a sync function"):
with pytest.raises(TypeError, match="expected a synchronous function"):
await to_thread_run_sync(thread_fn)


Expand Down Expand Up @@ -810,25 +809,32 @@ def test_from_thread_run_during_shutdown():
save = []
record = []

async def agen():
async def agen(token):
try:
yield
finally:
with pytest.raises(_core.RunFinishedError), _core.CancelScope(shield=True):
await to_thread_run_sync(from_thread_run, sleep, 0)
record.append("ok")

async def main():
save.append(agen())
with _core.CancelScope(shield=True):
try:
await to_thread_run_sync(
partial(from_thread_run, sleep, 0, trio_token=token)
)
except _core.RunFinishedError:
record.append("finished")
else:
record.append("clean")

async def main(use_system_task):
save.append(agen(_core.current_trio_token() if use_system_task else None))
await save[-1].asend(None)

_core.run(main)
assert record == ["ok"]
_core.run(main, True) # System nursery will be closed and raise RunFinishedError
_core.run(main, False) # host task will be rescheduled as normal
assert record == ["finished", "clean"]


async def test_trio_token_weak_referenceable():
token = current_trio_token()
assert isinstance(token, TrioToken)
token = _core.current_trio_token()
assert isinstance(token, _core.TrioToken)
weak_reference = weakref.ref(token)
assert token is weak_reference()

Expand All @@ -842,3 +848,170 @@ def __bool__(self):

with pytest.raises(NotImplementedError):
await to_thread_run_sync(int, cancellable=BadBool())


async def test_from_thread_reuses_task():
task = _core.current_task()

async def async_current_task():
return _core.current_task()

assert task is await to_thread_run_sync(from_thread_run_sync, _core.current_task)
assert task is await to_thread_run_sync(from_thread_run, async_current_task)


async def test_recursive_to_thread():
tid = None

def get_tid_then_reenter():
nonlocal tid
tid = threading.get_ident()
return from_thread_run(to_thread_run_sync, threading.get_ident)

assert tid != await to_thread_run_sync(get_tid_then_reenter)


async def test_from_thread_host_cancelled():
queue = stdlib_queue.Queue()

def sync_check():
from_thread_run_sync(cancel_scope.cancel)
try:
from_thread_run_sync(bool)
except _core.Cancelled: # pragma: no cover
queue.put(True) # sync functions don't raise Cancelled
else:
queue.put(False)

with _core.CancelScope() as cancel_scope:
await to_thread_run_sync(sync_check)

assert not cancel_scope.cancelled_caught
assert not queue.get_nowait()

with _core.CancelScope() as cancel_scope:
await to_thread_run_sync(sync_check, cancellable=True)

assert cancel_scope.cancelled_caught
assert not await to_thread_run_sync(partial(queue.get, timeout=1))

async def no_checkpoint():
return True

def async_check():
from_thread_run_sync(cancel_scope.cancel)
try:
assert from_thread_run(no_checkpoint)
except _core.Cancelled: # pragma: no cover
queue.put(True) # async functions raise Cancelled at checkpoints
else:
queue.put(False)

with _core.CancelScope() as cancel_scope:
await to_thread_run_sync(async_check)

assert not cancel_scope.cancelled_caught
assert not queue.get_nowait()

with _core.CancelScope() as cancel_scope:
await to_thread_run_sync(async_check, cancellable=True)

assert cancel_scope.cancelled_caught
assert not await to_thread_run_sync(partial(queue.get, timeout=1))

async def async_time_bomb():
cancel_scope.cancel()
with fail_after(10):
await sleep_forever()

with _core.CancelScope() as cancel_scope:
await to_thread_run_sync(from_thread_run, async_time_bomb)

assert cancel_scope.cancelled_caught


async def test_from_thread_check_cancelled():
q = stdlib_queue.Queue()

async def child(cancellable, scope):
with scope:
record.append("start")
try:
return await to_thread_run_sync(f, cancellable=cancellable)
except _core.Cancelled:
record.append("cancel")
raise
finally:
record.append("exit")

def f():
try:
from_thread_check_cancelled()
except _core.Cancelled: # pragma: no cover, test failure path
q.put("Cancelled")
else:
q.put("Not Cancelled")
ev.wait()
return from_thread_check_cancelled()

# Base case: nothing cancelled so we shouldn't see cancels anywhere
record = []
ev = threading.Event()
async with _core.open_nursery() as nursery:
nursery.start_soon(child, False, _core.CancelScope())
await wait_all_tasks_blocked()
assert record[0] == "start"
assert q.get(timeout=1) == "Not Cancelled"
ev.set()
# implicit assertion, Cancelled not raised via nursery
assert record[1] == "exit"

# cancellable=False case: a cancel will pop out but be handled by
# the appropriate cancel scope
record = []
ev = threading.Event()
scope = _core.CancelScope() # Nursery cancel scope gives false positives
async with _core.open_nursery() as nursery:
nursery.start_soon(child, False, scope)
await wait_all_tasks_blocked()
assert record[0] == "start"
assert q.get(timeout=1) == "Not Cancelled"
scope.cancel()
ev.set()
assert scope.cancelled_caught
assert "cancel" in record
assert record[-1] == "exit"

# cancellable=True case: slightly different thread behavior needed
# check thread is cancelled "soon" after abandonment
def f(): # noqa: F811
ev.wait()
try:
from_thread_check_cancelled()
except _core.Cancelled:
q.put("Cancelled")
else: # pragma: no cover, test failure path
q.put("Not Cancelled")

record = []
ev = threading.Event()
scope = _core.CancelScope()
async with _core.open_nursery() as nursery:
nursery.start_soon(child, True, scope)
await wait_all_tasks_blocked()
assert record[0] == "start"
scope.cancel()
ev.set()
assert scope.cancelled_caught
assert "cancel" in record
assert record[-1] == "exit"
assert q.get(timeout=1) == "Cancelled"


async def test_from_thread_check_cancelled_raises_in_foreign_threads():
with pytest.raises(RuntimeError):
from_thread_check_cancelled()
q = stdlib_queue.Queue()
_core.start_thread_soon(from_thread_check_cancelled, lambda _: q.put(_))
with pytest.raises(RuntimeError):
q.get(timeout=1).unwrap()
2 changes: 1 addition & 1 deletion trio/_tests/verify_types_darwin.json
Expand Up @@ -40,7 +40,7 @@
],
"exportedSymbolCounts": {
"withAmbiguousType": 0,
"withKnownType": 630,
"withKnownType": 631,
"withUnknownType": 0
},
"ignoreUnknownTypesFromImports": true,
Expand Down
2 changes: 1 addition & 1 deletion trio/_tests/verify_types_linux.json
Expand Up @@ -28,7 +28,7 @@
],
"exportedSymbolCounts": {
"withAmbiguousType": 0,
"withKnownType": 627,
"withKnownType": 628,
"withUnknownType": 0
},
"ignoreUnknownTypesFromImports": true,
Expand Down
2 changes: 1 addition & 1 deletion trio/_tests/verify_types_windows.json
Expand Up @@ -64,7 +64,7 @@
],
"exportedSymbolCounts": {
"withAmbiguousType": 0,
"withKnownType": 630,
"withKnownType": 631,
"withUnknownType": 0
},
"ignoreUnknownTypesFromImports": true,
Expand Down

0 comments on commit b324b3a

Please sign in to comment.