Skip to content

Commit

Permalink
implement cancellation semantics suggestions from code review
Browse files Browse the repository at this point in the history
  • Loading branch information
richardsheridan committed Oct 8, 2023
1 parent 9f4e79e commit 0e18c93
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 8 deletions.
5 changes: 3 additions & 2 deletions trio/_tests/test_threads.py
Original file line number Diff line number Diff line change
Expand Up @@ -879,6 +879,7 @@ def sync_check():
try:
from_thread_run_sync(bool)
except _core.Cancelled:

Check warning on line 881 in trio/_tests/test_threads.py

View check run for this annotation

Codecov / codecov/patch

trio/_tests/test_threads.py#L881

Added line #L881 was not covered by tests
# pragma: no cover, sync functions don't raise Cancelled
queue.put(True)

Check warning on line 883 in trio/_tests/test_threads.py

View check run for this annotation

Codecov / codecov/patch

trio/_tests/test_threads.py#L883

Added line #L883 was not covered by tests
else:
queue.put(False)
Expand All @@ -893,7 +894,7 @@ def sync_check():
await to_thread_run_sync(sync_check, cancellable=True)

assert cancel_scope.cancelled_caught
assert await to_thread_run_sync(partial(queue.get, timeout=1))
assert not await to_thread_run_sync(partial(queue.get, timeout=1))

async def no_checkpoint():
return True
Expand All @@ -917,7 +918,7 @@ def async_check():
await to_thread_run_sync(async_check, cancellable=True)

assert cancel_scope.cancelled_caught
assert await to_thread_run_sync(partial(queue.get, timeout=1))
assert not await to_thread_run_sync(partial(queue.get, timeout=1))

async def async_time_bomb():
cancel_scope.cancel()
Expand Down
18 changes: 12 additions & 6 deletions trio/_threads.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from trio._core._traps import RaiseCancelT

from ._core import (
CancelScope,
RunVar,
TrioToken,
disable_ki_protection,
Expand Down Expand Up @@ -86,6 +87,7 @@ class Run(Generic[RetT]):
queue: stdlib_queue.SimpleQueue[outcome.Outcome[RetT]] = attr.ib(
init=False, factory=stdlib_queue.SimpleQueue
)
scope: CancelScope = attr.ib(init=False, factory=CancelScope)

@disable_ki_protection
async def unprotected_afn(self) -> RetT:
Expand All @@ -106,7 +108,12 @@ async def run(self) -> None:
await trio.lowlevel.cancel_shielded_checkpoint()

async def run_system(self) -> None:
result = await outcome.acapture(self.unprotected_afn)
# NOTE: There is potential here to only conditionally enter a CancelScope
# when we need it, sparing some computation. But doing so adds substantial
# complexity, so we'll leave it until real need is demonstrated.
with self.scope:
result = await outcome.acapture(self.unprotected_afn)
assert not self.scope.cancelled_caught, "any Cancelled should go to our parent"
self.queue.put_nowait(result)


Expand Down Expand Up @@ -403,13 +410,14 @@ def _send_message_to_host_task(
message: Run[RetT] | RunSync[RetT], trio_token: TrioToken
) -> None:
task_register = PARENT_TASK_DATA.task_register
cancel_register = PARENT_TASK_DATA.cancel_register

def in_trio_thread() -> None:
task = task_register[0]
if task is None:
raise_cancel = cancel_register[0]
message.queue.put_nowait(outcome.capture(raise_cancel))
# Our parent task is gone! Punt to a system task.
if isinstance(message, Run):
message.scope.cancel()
_send_message_to_system_task(message, trio_token)
else:
trio.lowlevel.reschedule(task, outcome.Value(message))

Expand Down Expand Up @@ -509,8 +517,6 @@ def from_thread_run_sync(
Raises:
RunFinishedError: if the corresponding call to `trio.run` has
already completed.
Cancelled: if the corresponding `trio.to_thread.run_sync` task is
cancellable and exits before this function is called.
RuntimeError: if you try calling this from inside the Trio thread,
which would otherwise cause a deadlock or if no ``trio_token`` was
provided, and we can't infer one from context.
Expand Down

0 comments on commit 0e18c93

Please sign in to comment.