Skip to content

Commit

Permalink
Add types to _signals, test_signals (#2794)
Browse files Browse the repository at this point in the history
* Add types to _signals and test_signals

* Hide SignalReceiver from the API.

Convert this helper to a function to allow that.

* Enable strict type checking for _signals and test_signals

* Tweak comment wording

---------

Co-authored-by: CoolCat467 <52022020+CoolCat467@users.noreply.github.com>
  • Loading branch information
TeamSpen210 and CoolCat467 committed Sep 13, 2023
1 parent 2d8a8d9 commit c16003f
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 38 deletions.
3 changes: 0 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,6 @@ module = [
"trio/_core/_generated_io_windows",
"trio/_core/_io_windows",

"trio/_signals",

# internal
"trio/_windows_pipes",

Expand Down Expand Up @@ -93,7 +91,6 @@ module = [
"trio/_tests/test_highlevel_ssl_helpers",
"trio/_tests/test_path",
"trio/_tests/test_scheduler_determinism",
"trio/_tests/test_signals",
"trio/_tests/test_socket",
"trio/_tests/test_ssl",
"trio/_tests/test_subprocess",
Expand Down
44 changes: 30 additions & 14 deletions trio/_signals.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,19 @@
from __future__ import annotations

import signal
from collections import OrderedDict
from collections.abc import AsyncIterator, Callable, Generator, Iterable
from contextlib import contextmanager
from types import FrameType
from typing import TYPE_CHECKING

import trio

from ._util import ConflictDetector, is_main_thread, signal_raise

if TYPE_CHECKING:
from typing_extensions import Self

# Discussion of signal handling strategies:
#
# - On Windows signals barely exist. There are no options; signal handlers are
Expand Down Expand Up @@ -43,7 +51,10 @@


@contextmanager
def _signal_handler(signals, handler):
def _signal_handler(
signals: Iterable[int],
handler: Callable[[int, FrameType | None], object] | int | signal.Handlers | None,
) -> Generator[None, None, None]:
original_handlers = {}
try:
for signum in set(signals):
Expand All @@ -55,31 +66,31 @@ def _signal_handler(signals, handler):


class SignalReceiver:
def __init__(self):
def __init__(self) -> None:
# {signal num: None}
self._pending = OrderedDict()
self._pending: OrderedDict[int, None] = OrderedDict()
self._lot = trio.lowlevel.ParkingLot()
self._conflict_detector = ConflictDetector(
"only one task can iterate on a signal receiver at a time"
)
self._closed = False

def _add(self, signum):
def _add(self, signum: int) -> None:
if self._closed:
signal_raise(signum)
else:
self._pending[signum] = None
self._lot.unpark()

def _redeliver_remaining(self):
def _redeliver_remaining(self) -> None:
# First make sure that any signals still in the delivery pipeline will
# get redelivered
self._closed = True

# And then redeliver any that are sitting in pending. This is done
# using a weird recursive construct to make sure we process everything
# even if some of the handlers raise exceptions.
def deliver_next():
def deliver_next() -> None:
if self._pending:
signum, _ = self._pending.popitem(last=False)
try:
Expand All @@ -89,14 +100,10 @@ def deliver_next():

deliver_next()

# Helper for tests, not public or otherwise used
def _pending_signal_count(self):
return len(self._pending)

def __aiter__(self):
def __aiter__(self) -> Self:
return self

async def __anext__(self):
async def __anext__(self) -> int:
if self._closed:
raise RuntimeError("open_signal_receiver block already exited")
# In principle it would be possible to support multiple concurrent
Expand All @@ -111,8 +118,17 @@ async def __anext__(self):
return signum


def get_pending_signal_count(rec: AsyncIterator[int]) -> int:
"""Helper for tests, not public or otherwise used."""
# open_signal_receiver() always produces SignalReceiver, this should not fail.
assert isinstance(rec, SignalReceiver)
return len(rec._pending)


@contextmanager
def open_signal_receiver(*signals):
def open_signal_receiver(
*signals: signal.Signals | int,
) -> Generator[AsyncIterator[int], None, None]:
"""A context manager for catching signals.
Entering this context manager starts listening for the given signals and
Expand Down Expand Up @@ -158,7 +174,7 @@ def open_signal_receiver(*signals):
token = trio.lowlevel.current_trio_token()
queue = SignalReceiver()

def handler(signum, _):
def handler(signum: int, frame: FrameType | None) -> None:
token.run_sync_soon(queue._add, signum, idempotent=True)

try:
Expand Down
46 changes: 25 additions & 21 deletions trio/_tests/test_signals.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
from __future__ import annotations

import signal
from types import FrameType
from typing import NoReturn

import pytest

import trio

from .. import _core
from .._signals import _signal_handler, open_signal_receiver
from .._signals import _signal_handler, get_pending_signal_count, open_signal_receiver
from .._util import signal_raise


async def test_open_signal_receiver():
async def test_open_signal_receiver() -> None:
orig = signal.getsignal(signal.SIGILL)
with open_signal_receiver(signal.SIGILL) as receiver:
# Raise it a few times, to exercise signal coalescing, both at the
Expand All @@ -22,18 +26,18 @@ async def test_open_signal_receiver():
async for signum in receiver: # pragma: no branch
assert signum == signal.SIGILL
break
assert receiver._pending_signal_count() == 0
assert get_pending_signal_count(receiver) == 0
signal_raise(signal.SIGILL)
async for signum in receiver: # pragma: no branch
assert signum == signal.SIGILL
break
assert receiver._pending_signal_count() == 0
assert get_pending_signal_count(receiver) == 0
with pytest.raises(RuntimeError):
await receiver.__anext__()
assert signal.getsignal(signal.SIGILL) is orig


async def test_open_signal_receiver_restore_handler_after_one_bad_signal():
async def test_open_signal_receiver_restore_handler_after_one_bad_signal() -> None:
orig = signal.getsignal(signal.SIGILL)
with pytest.raises(ValueError):
with open_signal_receiver(signal.SIGILL, 1234567):
Expand All @@ -42,30 +46,30 @@ async def test_open_signal_receiver_restore_handler_after_one_bad_signal():
assert signal.getsignal(signal.SIGILL) is orig


async def test_open_signal_receiver_empty_fail():
async def test_open_signal_receiver_empty_fail() -> None:
with pytest.raises(TypeError, match="No signals were provided"):
with open_signal_receiver():
pass


async def test_open_signal_receiver_restore_handler_after_duplicate_signal():
async def test_open_signal_receiver_restore_handler_after_duplicate_signal() -> None:
orig = signal.getsignal(signal.SIGILL)
with open_signal_receiver(signal.SIGILL, signal.SIGILL):
pass
# Still restored correctly
assert signal.getsignal(signal.SIGILL) is orig


async def test_catch_signals_wrong_thread():
async def naughty():
async def test_catch_signals_wrong_thread() -> None:
async def naughty() -> None:
with open_signal_receiver(signal.SIGINT):
pass # pragma: no cover

with pytest.raises(RuntimeError):
await trio.to_thread.run_sync(trio.run, naughty)


async def test_open_signal_receiver_conflict():
async def test_open_signal_receiver_conflict() -> None:
with pytest.raises(trio.BusyResourceError):
with open_signal_receiver(signal.SIGILL) as receiver:
async with trio.open_nursery() as nursery:
Expand All @@ -75,14 +79,14 @@ async def test_open_signal_receiver_conflict():

# Blocks until all previous calls to run_sync_soon(idempotent=True) have been
# processed.
async def wait_run_sync_soon_idempotent_queue_barrier():
async def wait_run_sync_soon_idempotent_queue_barrier() -> None:
ev = trio.Event()
token = _core.current_trio_token()
token.run_sync_soon(ev.set, idempotent=True)
await ev.wait()


async def test_open_signal_receiver_no_starvation():
async def test_open_signal_receiver_no_starvation() -> None:
# Set up a situation where there are always 2 pending signals available to
# report, and make sure that instead of getting the same signal reported
# over and over, it alternates between reporting both of them.
Expand All @@ -101,8 +105,8 @@ async def test_open_signal_receiver_no_starvation():
assert got in [signal.SIGILL, signal.SIGFPE]
assert got != previous
previous = got
# Clear out the last signal so it doesn't get redelivered
while receiver._pending_signal_count() != 0:
# Clear out the last signal so that it doesn't get redelivered
while get_pending_signal_count(receiver) != 0:
await receiver.__anext__()
except: # pragma: no cover
# If there's an unhandled exception above, then exiting the
Expand All @@ -113,10 +117,10 @@ async def test_open_signal_receiver_no_starvation():
traceback.print_exc()


async def test_catch_signals_race_condition_on_exit():
delivered_directly = set()
async def test_catch_signals_race_condition_on_exit() -> None:
delivered_directly: set[int] = set()

def direct_handler(signo, frame):
def direct_handler(signo: int, frame: FrameType | None) -> None:
delivered_directly.add(signo)

print(1)
Expand All @@ -138,7 +142,7 @@ def direct_handler(signo, frame):
signal_raise(signal.SIGILL)
signal_raise(signal.SIGFPE)
await wait_run_sync_soon_idempotent_queue_barrier()
assert receiver._pending_signal_count() == 2
assert get_pending_signal_count(receiver) == 2
assert delivered_directly == {signal.SIGILL, signal.SIGFPE}
delivered_directly.clear()

Expand All @@ -156,12 +160,12 @@ def direct_handler(signo, frame):
with open_signal_receiver(signal.SIGILL) as receiver:
signal_raise(signal.SIGILL)
await wait_run_sync_soon_idempotent_queue_barrier()
assert receiver._pending_signal_count() == 1
assert get_pending_signal_count(receiver) == 1
# test passes if the process reaches this point without dying

# Check exception chaining if there are multiple exception-raising
# handlers
def raise_handler(signum, _):
def raise_handler(signum: int, frame: FrameType | None) -> NoReturn:
raise RuntimeError(signum)

with _signal_handler({signal.SIGILL, signal.SIGFPE}, raise_handler):
Expand All @@ -170,7 +174,7 @@ def raise_handler(signum, _):
signal_raise(signal.SIGILL)
signal_raise(signal.SIGFPE)
await wait_run_sync_soon_idempotent_queue_barrier()
assert receiver._pending_signal_count() == 2
assert get_pending_signal_count(receiver) == 2
exc = excinfo.value
signums = {exc.args[0]}
assert isinstance(exc.__context__, RuntimeError)
Expand Down

0 comments on commit c16003f

Please sign in to comment.