diff --git a/check.sh b/check.sh index badad99127..3e07056dcd 100755 --- a/check.sh +++ b/check.sh @@ -110,6 +110,7 @@ if [ $PYRIGHT -ne 0 ]; then fi pyright src/trio/_tests/type_tests || EXIT_STATUS=$? +pyright src/trio/_core/_tests/type_tests || EXIT_STATUS=$? echo "::endgroup::" # Finally, leave a really clear warning of any issues and exit diff --git a/docs/source/conf.py b/docs/source/conf.py index af358abf15..1e508997ca 100755 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -73,7 +73,6 @@ # aliasing doesn't actually fix the warning for types.FrameType, but displaying # "types.FrameType" is more helpful than just "frame" "FrameType": "types.FrameType", - "Context": "OpenSSL.SSL.Context", # SSLListener.accept's return type is seen as trio._ssl.SSLStream "SSLStream": "trio.SSLStream", } @@ -91,6 +90,8 @@ def autodoc_process_signature( # Strip the type from the union, make it look like = ... signature = signature.replace(" | type[trio._core._local._NoValue]", "") signature = signature.replace("", "...") + if "DTLS" in name: + signature = signature.replace("SSL.Context", "OpenSSL.SSL.Context") # Don't specify PathLike[str] | PathLike[bytes], this is just for humans. signature = signature.replace("StrOrBytesPath", "str | bytes | os.PathLike") diff --git a/newsfragments/2881.feature.rst b/newsfragments/2881.feature.rst new file mode 100644 index 0000000000..4e8efe47f3 --- /dev/null +++ b/newsfragments/2881.feature.rst @@ -0,0 +1 @@ +`TypeVarTuple `_ is now used to fully type :meth:`nursery.start_soon() `, :func:`trio.run()`, :func:`trio.to_thread.run_sync()`, and other similar functions accepting ``(func, *args)``. This means type checkers will be able to verify types are used correctly. :meth:`nursery.start() ` is not fully typed yet however. diff --git a/pyproject.toml b/pyproject.toml index cb24dd9eb8..b4733d7de9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -157,6 +157,8 @@ check_untyped_defs = true [tool.pyright] pythonVersion = "3.8" +reportUnnecessaryTypeIgnoreComment = true +typeCheckingMode = "strict" [tool.pytest.ini_options] addopts = ["--strict-markers", "--strict-config", "-p trio._tests.pytest_plugin"] @@ -235,6 +237,8 @@ omit = [ "*/trio/_core/_tests/test_multierror_scripts/*", # Omit the generated files in trio/_core starting with _generated_ "*/trio/_core/_generated_*", + # Type tests aren't intended to be run, just passed to type checkers. + "*/type_tests/*", ] # The test suite spawns subprocesses to test some stuff, so make sure # this doesn't corrupt the coverage files diff --git a/src/trio/_core/_entry_queue.py b/src/trio/_core/_entry_queue.py index cb91025fbb..582441e7d8 100644 --- a/src/trio/_core/_entry_queue.py +++ b/src/trio/_core/_entry_queue.py @@ -2,7 +2,7 @@ import threading from collections import deque -from typing import Callable, Iterable, NoReturn, Tuple +from typing import TYPE_CHECKING, Callable, NoReturn, Tuple import attr @@ -10,10 +10,13 @@ from .._util import NoPublicConstructor, final from ._wakeup_socketpair import WakeupSocketpair -# TODO: Type with TypeVarTuple, at least to an extent where it makes -# the public interface safe. +if TYPE_CHECKING: + from typing_extensions import TypeVarTuple, Unpack + + PosArgsT = TypeVarTuple("PosArgsT") + Function = Callable[..., object] -Job = Tuple[Function, Iterable[object]] +Job = Tuple[Function, Tuple[object, ...]] @attr.s(slots=True) @@ -122,7 +125,10 @@ def size(self) -> int: return len(self.queue) + len(self.idempotent_queue) def run_sync_soon( - self, sync_fn: Function, *args: object, idempotent: bool = False + self, + sync_fn: Callable[[Unpack[PosArgsT]], object], + *args: Unpack[PosArgsT], + idempotent: bool = False, ) -> None: with self.lock: if self.done: @@ -163,7 +169,10 @@ class TrioToken(metaclass=NoPublicConstructor): _reentry_queue: EntryQueue = attr.ib() def run_sync_soon( - self, sync_fn: Function, *args: object, idempotent: bool = False + self, + sync_fn: Callable[[Unpack[PosArgsT]], object], + *args: Unpack[PosArgsT], + idempotent: bool = False, ) -> None: """Schedule a call to ``sync_fn(*args)`` to occur in the context of a Trio task. diff --git a/src/trio/_core/_generated_instrumentation.py b/src/trio/_core/_generated_instrumentation.py index debd1e7bb5..e9c7250f6e 100644 --- a/src/trio/_core/_generated_instrumentation.py +++ b/src/trio/_core/_generated_instrumentation.py @@ -11,6 +11,8 @@ if TYPE_CHECKING: from ._instrumentation import Instrument +__all__ = ["add_instrument", "remove_instrument"] + def add_instrument(instrument: Instrument) -> None: """Start instrumenting the current run loop with the given instrument. diff --git a/src/trio/_core/_generated_io_epoll.py b/src/trio/_core/_generated_io_epoll.py index d2547e7619..704a67d557 100644 --- a/src/trio/_core/_generated_io_epoll.py +++ b/src/trio/_core/_generated_io_epoll.py @@ -15,6 +15,9 @@ assert not TYPE_CHECKING or sys.platform == "linux" +__all__ = ["notify_closing", "wait_readable", "wait_writable"] + + async def wait_readable(fd: (int | _HasFileNo)) -> None: """Block until the kernel reports that the given object is readable. diff --git a/src/trio/_core/_generated_io_kqueue.py b/src/trio/_core/_generated_io_kqueue.py index 18467f0447..39662fd902 100644 --- a/src/trio/_core/_generated_io_kqueue.py +++ b/src/trio/_core/_generated_io_kqueue.py @@ -19,6 +19,16 @@ assert not TYPE_CHECKING or sys.platform == "darwin" +__all__ = [ + "current_kqueue", + "monitor_kevent", + "notify_closing", + "wait_kevent", + "wait_readable", + "wait_writable", +] + + def current_kqueue() -> select.kqueue: """TODO: these are implemented, but are currently more of a sketch than anything real. See `#26 diff --git a/src/trio/_core/_generated_io_windows.py b/src/trio/_core/_generated_io_windows.py index b705a77267..bb23e630c2 100644 --- a/src/trio/_core/_generated_io_windows.py +++ b/src/trio/_core/_generated_io_windows.py @@ -19,6 +19,19 @@ assert not TYPE_CHECKING or sys.platform == "win32" +__all__ = [ + "current_iocp", + "monitor_completion_key", + "notify_closing", + "readinto_overlapped", + "register_with_iocp", + "wait_overlapped", + "wait_readable", + "wait_writable", + "write_overlapped", +] + + async def wait_readable(sock: (_HasFileNo | int)) -> None: """Block until the kernel reports that the given object is readable. diff --git a/src/trio/_core/_generated_run.py b/src/trio/_core/_generated_run.py index 1ec415497c..30b6e7d1c7 100644 --- a/src/trio/_core/_generated_run.py +++ b/src/trio/_core/_generated_run.py @@ -13,9 +13,23 @@ from collections.abc import Awaitable, Callable from outcome import Outcome + from typing_extensions import Unpack from .._abc import Clock from ._entry_queue import TrioToken + from ._run import PosArgT + + +__all__ = [ + "current_clock", + "current_root_task", + "current_statistics", + "current_time", + "current_trio_token", + "reschedule", + "spawn_system_task", + "wait_all_tasks_blocked", +] def current_statistics() -> RunStatistics: @@ -113,8 +127,8 @@ def reschedule(task: Task, next_send: Outcome[Any] = _NO_SEND) -> None: def spawn_system_task( - async_fn: Callable[..., Awaitable[object]], - *args: object, + async_fn: Callable[[Unpack[PosArgT]], Awaitable[object]], + *args: Unpack[PosArgT], name: object = None, context: (contextvars.Context | None) = None, ) -> Task: diff --git a/src/trio/_core/_run.py b/src/trio/_core/_run.py index f6beed46c5..015065863e 100644 --- a/src/trio/_core/_run.py +++ b/src/trio/_core/_run.py @@ -53,6 +53,12 @@ if sys.version_info < (3, 11): from exceptiongroup import BaseExceptionGroup +FnT = TypeVar("FnT", bound="Callable[..., Any]") +StatusT = TypeVar("StatusT") +StatusT_co = TypeVar("StatusT_co", covariant=True) +StatusT_contra = TypeVar("StatusT_contra", contravariant=True) +RetT = TypeVar("RetT") + if TYPE_CHECKING: import contextvars @@ -70,19 +76,25 @@ # for some strange reason Sphinx works with outcome.Outcome, but not Outcome, in # start_guest_run. Same with types.FrameType in iter_await_frames import outcome - from typing_extensions import Self + from typing_extensions import Self, TypeVarTuple, Unpack + + PosArgT = TypeVarTuple("PosArgT") + + # Needs to be guarded, since Unpack[] would be evaluated at runtime. + class _NurseryStartFunc(Protocol[Unpack[PosArgT], StatusT_co]): + """Type of functions passed to `nursery.start() `.""" + + def __call__( + self, *args: Unpack[PosArgT], task_status: TaskStatus[StatusT_co] + ) -> Awaitable[object]: + ... + DEADLINE_HEAP_MIN_PRUNE_THRESHOLD: Final = 1000 # Passed as a sentinel _NO_SEND: Final[Outcome[Any]] = cast("Outcome[Any]", object()) -FnT = TypeVar("FnT", bound="Callable[..., Any]") -StatusT = TypeVar("StatusT") -StatusT_co = TypeVar("StatusT_co", covariant=True) -StatusT_contra = TypeVar("StatusT_contra", contravariant=True) -RetT = TypeVar("RetT") - @final class _NoStatus(metaclass=NoPublicConstructor): @@ -1119,9 +1131,8 @@ def aborted(raise_cancel: _core.RaiseCancelT) -> Abort: def start_soon( self, - # TODO: TypeVarTuple - async_fn: Callable[..., Awaitable[object]], - *args: object, + async_fn: Callable[[Unpack[PosArgT]], Awaitable[object]], + *args: Unpack[PosArgT], name: object = None, ) -> None: """Creates a child task, scheduling ``await async_fn(*args)``. @@ -1170,7 +1181,7 @@ async def start( async_fn: Callable[..., Awaitable[object]], *args: object, name: object = None, - ) -> StatusT: + ) -> Any: r"""Creates and initializes a child task. Like :meth:`start_soon`, but blocks until the new task has @@ -1219,7 +1230,7 @@ async def async_fn(arg1, arg2, *, task_status=trio.TASK_STATUS_IGNORED): # `run` option, which would cause it to wrap a pre-started() # exception in an extra ExceptionGroup. See #2611. async with open_nursery(strict_exception_groups=False) as old_nursery: - task_status: _TaskStatus[StatusT] = _TaskStatus(old_nursery, self) + task_status: _TaskStatus[Any] = _TaskStatus(old_nursery, self) thunk = functools.partial(async_fn, task_status=task_status) task = GLOBAL_RUN_CONTEXT.runner.spawn_impl( thunk, args, old_nursery, name @@ -1232,7 +1243,7 @@ async def async_fn(arg1, arg2, *, task_status=trio.TASK_STATUS_IGNORED): # (Any exceptions propagate directly out of the above.) if task_status._value is _NoStatus: raise RuntimeError("child exited without calling task_status.started()") - return task_status._value # type: ignore[return-value] # Mypy doesn't narrow yet. + return task_status._value finally: self._pending_starts -= 1 self._check_nursery_closed() @@ -1690,9 +1701,8 @@ def reschedule( # type: ignore[misc] def spawn_impl( self, - # TODO: TypeVarTuple - async_fn: Callable[..., Awaitable[object]], - args: tuple[object, ...], + async_fn: Callable[[Unpack[PosArgT]], Awaitable[object]], + args: tuple[Unpack[PosArgT]], nursery: Nursery | None, name: object, *, @@ -1721,7 +1731,8 @@ def spawn_impl( # Call the function and get the coroutine object, while giving helpful # errors for common mistakes. ###### - coro = context.run(coroutine_or_error, async_fn, *args) + # TypeVarTuple passed into ParamSpec function confuses Mypy. + coro = context.run(coroutine_or_error, async_fn, *args) # type: ignore[arg-type] if name is None: name = async_fn @@ -1808,12 +1819,11 @@ def task_exited(self, task: Task, outcome: Outcome[Any]) -> None: # System tasks and init ################ - @_public # Type-ignore due to use of Any here. - def spawn_system_task( # type: ignore[misc] + @_public + def spawn_system_task( self, - # TODO: TypeVarTuple - async_fn: Callable[..., Awaitable[object]], - *args: object, + async_fn: Callable[[Unpack[PosArgT]], Awaitable[object]], + *args: Unpack[PosArgT], name: object = None, context: contextvars.Context | None = None, ) -> Task: @@ -1878,10 +1888,9 @@ def spawn_system_task( # type: ignore[misc] ) async def init( - # TODO: TypeVarTuple self, - async_fn: Callable[..., Awaitable[object]], - args: tuple[object, ...], + async_fn: Callable[[Unpack[PosArgT]], Awaitable[object]], + args: tuple[Unpack[PosArgT]], ) -> None: # run_sync_soon task runs here: async with open_nursery() as run_sync_soon_nursery: @@ -2407,8 +2416,8 @@ def my_done_callback(run_outcome): # straight through. def unrolled_run( runner: Runner, - async_fn: Callable[..., object], - args: tuple[object, ...], + async_fn: Callable[[Unpack[PosArgT]], Awaitable[object]], + args: tuple[Unpack[PosArgT]], host_uses_signal_set_wakeup_fd: bool = False, ) -> Generator[float, EventResult, None]: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True diff --git a/src/trio/_core/_tests/test_guest_mode.py b/src/trio/_core/_tests/test_guest_mode.py index ddc7b480f8..aa912ab70e 100644 --- a/src/trio/_core/_tests/test_guest_mode.py +++ b/src/trio/_core/_tests/test_guest_mode.py @@ -658,8 +658,6 @@ async def trio_main() -> None: # Ensure we don't pollute the thread-level context if run under # an asyncio without contextvars support (3.6) context = contextvars.copy_context() - if TYPE_CHECKING: - aiotrio_run(trio_main, host_uses_signal_set_wakeup_fd=True) context.run(aiotrio_run, trio_main, host_uses_signal_set_wakeup_fd=True) assert record == {("asyncio", "asyncio"), ("trio", "trio")} diff --git a/src/trio/_core/_tests/test_local.py b/src/trio/_core/_tests/test_local.py index 5fdf54b13c..17b10bca35 100644 --- a/src/trio/_core/_tests/test_local.py +++ b/src/trio/_core/_tests/test_local.py @@ -77,8 +77,8 @@ async def task1() -> None: t1.set("plaice") assert t1.get() == "plaice" - async def task2(tok: str) -> None: - t1.reset(token) + async def task2(tok: RunVarToken[str]) -> None: + t1.reset(tok) with pytest.raises(LookupError): t1.get() diff --git a/src/trio/_core/_tests/test_run.py b/src/trio/_core/_tests/test_run.py index 5bd98f0e91..212640e78f 100644 --- a/src/trio/_core/_tests/test_run.py +++ b/src/trio/_core/_tests/test_run.py @@ -1554,6 +1554,7 @@ async def child2() -> None: assert tasks["child2"].child_nurseries == [] async def child1( + *, task_status: _core.TaskStatus[None] = _core.TASK_STATUS_IGNORED, ) -> None: me = tasks["child1"] = _core.current_task() @@ -1774,6 +1775,7 @@ async def sleep_then_start( # calling started twice async def double_started( + *, task_status: _core.TaskStatus[None] = _core.TASK_STATUS_IGNORED, ) -> None: task_status.started() @@ -1785,6 +1787,7 @@ async def double_started( # child crashes before calling started -> error comes out of .start() async def raise_keyerror( + *, task_status: _core.TaskStatus[None] = _core.TASK_STATUS_IGNORED, ) -> None: raise KeyError("oops") @@ -1795,6 +1798,7 @@ async def raise_keyerror( # child exiting cleanly before calling started -> triggers a RuntimeError async def nothing( + *, task_status: _core.TaskStatus[None] = _core.TASK_STATUS_IGNORED, ) -> None: return @@ -1808,6 +1812,7 @@ async def nothing( # nothing -- the child keeps executing under start(). The value it passed # is ignored; start() raises Cancelled. async def just_started( + *, task_status: _core.TaskStatus[str] = _core.TASK_STATUS_IGNORED, ) -> None: task_status.started("hi") @@ -1989,7 +1994,7 @@ def __init__(self, *largs: it) -> None: self.nexts = [obj.__anext__ for obj in largs] async def _accumulate( - self, f: Callable[[], Awaitable[int]], items: list[int | None], i: int + self, f: Callable[[], Awaitable[int]], items: list[int], i: int ) -> None: items[i] = await f() diff --git a/src/trio/_core/_tests/type_tests/nursery_start.py b/src/trio/_core/_tests/type_tests/nursery_start.py new file mode 100644 index 0000000000..c4a2915bf0 --- /dev/null +++ b/src/trio/_core/_tests/type_tests/nursery_start.py @@ -0,0 +1,85 @@ +"""Test variadic generic typing for Nursery.start[_soon]().""" +from typing import Awaitable, Callable + +from trio import TASK_STATUS_IGNORED, Nursery, TaskStatus + + +async def task_0() -> None: + ... + + +async def task_1a(value: int) -> None: + ... + + +async def task_1b(value: str) -> None: + ... + + +async def task_2a(a: int, b: str) -> None: + ... + + +async def task_2b(a: str, b: int) -> None: + ... + + +async def task_2c(a: str, b: int, optional: bool = False) -> None: + ... + + +async def task_requires_kw(a: int, *, b: bool) -> None: + ... + + +async def task_startable_1( + a: str, + *, + task_status: TaskStatus[bool] = TASK_STATUS_IGNORED, +) -> None: + ... + + +async def task_startable_2( + a: str, + b: float, + *, + task_status: TaskStatus[bool] = TASK_STATUS_IGNORED, +) -> None: + ... + + +async def task_requires_start(*, task_status: TaskStatus[str]) -> None: + """Check a function requiring start() to be used.""" + + +async def task_pos_or_kw(value: str, task_status: TaskStatus[int]) -> None: + """Check a function which doesn't use the *-syntax works.""" + ... + + +def check_start_soon(nursery: Nursery) -> None: + """start_soon() functionality.""" + nursery.start_soon(task_0) + nursery.start_soon(task_1a) # type: ignore + nursery.start_soon(task_2b) # type: ignore + + nursery.start_soon(task_0, 45) # type: ignore + nursery.start_soon(task_1a, 32) + nursery.start_soon(task_1b, 32) # type: ignore + nursery.start_soon(task_1a, "abc") # type: ignore + nursery.start_soon(task_1b, "abc") + + nursery.start_soon(task_2b, "abc") # type: ignore + nursery.start_soon(task_2a, 38, "46") + nursery.start_soon(task_2c, "abc", 12, True) + + nursery.start_soon(task_2c, "abc", 12) + task_2c_cast: Callable[ + [str, int], Awaitable[object] + ] = task_2c # The assignment makes it work. + nursery.start_soon(task_2c_cast, "abc", 12) + + nursery.start_soon(task_requires_kw, 12, True) # type: ignore + # Tasks following the start() API can be made to work. + nursery.start_soon(task_startable_1, "cdf") diff --git a/src/trio/_dtls.py b/src/trio/_dtls.py index 7d1969bab4..4ad6d21751 100644 --- a/src/trio/_dtls.py +++ b/src/trio/_dtls.py @@ -41,11 +41,12 @@ # See DTLSEndpoint.__init__ for why this is imported here from OpenSSL import SSL # noqa: TCH004 - from OpenSSL.SSL import Context - from typing_extensions import Self, TypeAlias + from typing_extensions import Self, TypeAlias, TypeVarTuple, Unpack from trio.socket import SocketType + PosArgsT = TypeVarTuple("PosArgsT") + MAX_UDP_PACKET_SIZE = 65527 @@ -830,7 +831,12 @@ class DTLSChannel(trio.abc.Channel[bytes], metaclass=NoPublicConstructor): """ - def __init__(self, endpoint: DTLSEndpoint, peer_address: Any, ctx: Context): + def __init__( + self, + endpoint: DTLSEndpoint, + peer_address: Any, + ctx: SSL.Context, + ) -> None: self.endpoint = endpoint self.peer_address = peer_address self._packets_dropped_in_trio = 0 @@ -1176,7 +1182,12 @@ class DTLSEndpoint: """ - def __init__(self, socket: SocketType, *, incoming_packets_buffer: int = 10): + def __init__( + self, + socket: SocketType, + *, + incoming_packets_buffer: int = 10, + ) -> None: # We do this lazily on first construction, so only people who actually use DTLS # have to install PyOpenSSL. global SSL @@ -1197,7 +1208,7 @@ def __init__(self, socket: SocketType, *, incoming_packets_buffer: int = 10): # old connection. # {remote address: DTLSChannel} self._streams: WeakValueDictionary[Any, DTLSChannel] = WeakValueDictionary() - self._listening_context: Context | None = None + self._listening_context: SSL.Context | None = None self._listening_key: bytes | None = None self._incoming_connections_q = _Queue[DTLSChannel](float("inf")) self._send_lock = trio.Lock() @@ -1258,14 +1269,11 @@ def _check_closed(self) -> None: if self._closed: raise trio.ClosedResourceError - # async_fn cannot be typed with ParamSpec, since we don't accept - # kwargs. Can be typed with TypeVarTuple once it's fully supported - # in mypy. async def serve( self, - ssl_context: Context, - async_fn: Callable[..., Awaitable[object]], - *args: Any, + ssl_context: SSL.Context, + async_fn: Callable[[DTLSChannel, Unpack[PosArgsT]], Awaitable[object]], + *args: Unpack[PosArgsT], task_status: trio.TaskStatus[None] = trio.TASK_STATUS_IGNORED, ) -> None: """Listen for incoming connections, and spawn a handler for each using an @@ -1294,6 +1302,7 @@ async def handler(dtls_channel): incoming connections. async_fn: The handler function that will be invoked for each incoming connection. + *args: Additional arguments to pass to the handler function. """ self._check_closed() @@ -1324,7 +1333,11 @@ async def handler_wrapper(stream: DTLSChannel) -> None: finally: self._listening_context = None - def connect(self, address: tuple[str, int], ssl_context: Context) -> DTLSChannel: + def connect( + self, + address: tuple[str, int], + ssl_context: SSL.Context, + ) -> DTLSChannel: """Initiate an outgoing DTLS connection. Notice that this is a synchronous method. That's because it doesn't actually diff --git a/src/trio/_subprocess.py b/src/trio/_subprocess.py index b9bdff1a75..4d325b2fb5 100644 --- a/src/trio/_subprocess.py +++ b/src/trio/_subprocess.py @@ -764,14 +764,17 @@ async def read_output( proc = await open_process(command, **options) # type: ignore[arg-type, call-overload, unused-ignore] try: if input is not None: + assert proc.stdin is not None nursery.start_soon(feed_input, proc.stdin) proc.stdin = None proc.stdio = None if capture_stdout: + assert proc.stdout is not None nursery.start_soon(read_output, proc.stdout, stdout_chunks) proc.stdout = None proc.stdio = None if capture_stderr: + assert proc.stderr is not None nursery.start_soon(read_output, proc.stderr, stderr_chunks) proc.stderr = None task_status.started(proc) diff --git a/src/trio/_tests/test_channel.py b/src/trio/_tests/test_channel.py index c81933b6b7..ae555715cb 100644 --- a/src/trio/_tests/test_channel.py +++ b/src/trio/_tests/test_channel.py @@ -151,17 +151,17 @@ async def receive_block(r: trio.MemoryReceiveChannel[int]) -> None: with pytest.raises(trio.ClosedResourceError): await r.receive() - s, r = open_memory_channel[None](0) + s2, r2 = open_memory_channel[int](0) async with trio.open_nursery() as nursery: - nursery.start_soon(receive_block, r) + nursery.start_soon(receive_block, r2) await wait_all_tasks_blocked() - await r.aclose() + await r2.aclose() # and it's persistent with pytest.raises(trio.ClosedResourceError): - r.receive_nowait() + r2.receive_nowait() with pytest.raises(trio.ClosedResourceError): - await r.receive() + await r2.receive() async def test_close_sync() -> None: @@ -204,7 +204,7 @@ async def send_block( await s.send(None) # closing receive -> other receive gets ClosedResourceError - async def receive_block(r: trio.MemoryReceiveChannel[int]) -> None: + async def receive_block(r: trio.MemoryReceiveChannel[None]) -> None: with pytest.raises(trio.ClosedResourceError): await r.receive() @@ -366,9 +366,9 @@ async def test_channel_fairness() -> None: # But if someone else is waiting to receive, then they "own" the item we # send, so we can't receive it (even though we run first): - result = None + result: int | None = None - async def do_receive(r: trio.MemoryReceiveChannel[int]) -> None: + async def do_receive(r: trio.MemoryReceiveChannel[int | None]) -> None: nonlocal result result = await r.receive() diff --git a/src/trio/_tests/test_scheduler_determinism.py b/src/trio/_tests/test_scheduler_determinism.py index bf0eec3d39..02733226cf 100644 --- a/src/trio/_tests/test_scheduler_determinism.py +++ b/src/trio/_tests/test_scheduler_determinism.py @@ -19,7 +19,7 @@ async def tracer(name: str) -> None: async with trio.open_nursery() as nursery: for i in range(5): - nursery.start_soon(tracer, i) + nursery.start_soon(tracer, str(i)) return tuple(trace) diff --git a/src/trio/_tests/test_subprocess.py b/src/trio/_tests/test_subprocess.py index 27642f1775..8099ea446d 100644 --- a/src/trio/_tests/test_subprocess.py +++ b/src/trio/_tests/test_subprocess.py @@ -42,7 +42,7 @@ from typing_extensions import TypeAlias - from .._abc import Stream + from .._abc import ReceiveStream if sys.platform == "win32": SignalType: TypeAlias = None @@ -220,12 +220,15 @@ async def feed_input() -> None: await proc.stdin.send_all(msg) await proc.stdin.aclose() - async def check_output(stream: Stream, expected: bytes) -> None: + async def check_output(stream: ReceiveStream, expected: bytes) -> None: seen = bytearray() async for chunk in stream: seen += chunk assert seen == expected + assert proc.stdout is not None + assert proc.stderr is not None + async with _core.open_nursery() as nursery: # fail eventually if something is broken nursery.cancel_scope.deadline = _core.current_time() + 30.0 @@ -270,7 +273,9 @@ async def test_interactive(background_process: BackgroundProcessType) -> None: async def expect(idx: int, request: int) -> None: async with _core.open_nursery() as nursery: - async def drain_one(stream: Stream, count: int, digit: int) -> None: + async def drain_one( + stream: ReceiveStream, count: int, digit: int + ) -> None: while count > 0: result = await stream.receive_some(count) assert result == (f"{digit}".encode() * len(result)) @@ -278,6 +283,8 @@ async def drain_one(stream: Stream, count: int, digit: int) -> None: assert count == 0 assert await stream.receive_some(len(newline)) == newline + assert proc.stdout is not None + assert proc.stderr is not None nursery.start_soon(drain_one, proc.stdout, request, idx * 2) nursery.start_soon(drain_one, proc.stderr, request * 2, idx * 2 + 1) @@ -610,9 +617,7 @@ async def test_warn_on_cancel_SIGKILL_escalation( async def test_run_process_background_fail() -> None: with pytest.raises(subprocess.CalledProcessError): async with _core.open_nursery() as nursery: - proc: subprocess.CompletedProcess[bytes] = await nursery.start( - run_process, EXIT_FALSE - ) + proc: Process = await nursery.start(run_process, EXIT_FALSE) assert proc.returncode == 1 diff --git a/src/trio/_tests/test_sync.py b/src/trio/_tests/test_sync.py index 9179c8a5ae..dd78d2e82f 100644 --- a/src/trio/_tests/test_sync.py +++ b/src/trio/_tests/test_sync.py @@ -546,7 +546,7 @@ async def test_generic_lock_fifo_fairness(lock_factory: LockFactory) -> None: record = [] LOOPS = 5 - async def loopy(name: str, lock_like: LockLike) -> None: + async def loopy(name: int, lock_like: LockLike) -> None: # Record the order each task was initially scheduled in initial_order.append(name) for _ in range(LOOPS): diff --git a/src/trio/_tests/type_tests/path.py b/src/trio/_tests/type_tests/path.py index 321fd1043b..7c8c6de4a2 100644 --- a/src/trio/_tests/type_tests/path.py +++ b/src/trio/_tests/type_tests/path.py @@ -6,7 +6,7 @@ from typing import IO, Any, BinaryIO, List, Tuple import trio -from trio._path import _AsyncIOWrapper +from trio._path import _AsyncIOWrapper # pyright: ignore[reportPrivateUsage] from typing_extensions import assert_type diff --git a/src/trio/_tools/gen_exports.py b/src/trio/_tools/gen_exports.py index 3227a06018..7156d7015d 100755 --- a/src/trio/_tools/gen_exports.py +++ b/src/trio/_tools/gen_exports.py @@ -220,10 +220,12 @@ def gen_public_wrappers_source(file: File) -> str: generated = ["".join(header)] source = astor.code_to_ast.parse_file(file.path) + method_names = [] for method in get_public_methods(source): # Remove self from arguments assert method.args.args[0].arg == "self" del method.args.args[0] + method_names.append(method.name) for dec in method.decorator_list: # pragma: no cover if isinstance(dec, ast.Name) and dec.id == "contextmanager": @@ -263,6 +265,10 @@ def gen_public_wrappers_source(file: File) -> str: # Append the snippet to the corresponding module generated.append(snippet) + + method_names.sort() + # Insert after the header, before function definitions + generated.insert(1, f"__all__ = {method_names!r}") return "\n\n".join(generated) @@ -346,7 +352,7 @@ def main() -> None: # pragma: no cover IMPORTS_RUN = """\ from collections.abc import Awaitable, Callable -from typing import Any +from typing import Any, TYPE_CHECKING from outcome import Outcome import contextvars @@ -354,6 +360,10 @@ def main() -> None: # pragma: no cover from ._run import _NO_SEND, RunStatistics, Task from ._entry_queue import TrioToken from .._abc import Clock + +if TYPE_CHECKING: + from typing_extensions import Unpack + from ._run import PosArgT """ IMPORTS_INSTRUMENT = """\ from ._instrumentation import Instrument diff --git a/src/trio/_util.py b/src/trio/_util.py index 5f514cb532..f4d4195128 100644 --- a/src/trio/_util.py +++ b/src/trio/_util.py @@ -21,9 +21,10 @@ if t.TYPE_CHECKING: from types import AsyncGeneratorType, TracebackType - from typing_extensions import ParamSpec, Self + from typing_extensions import ParamSpec, Self, TypeVarTuple, Unpack ArgsT = ParamSpec("ArgsT") + PosArgsT = TypeVarTuple("PosArgsT") if t.TYPE_CHECKING: @@ -102,9 +103,9 @@ def is_main_thread() -> bool: # Call the function and get the coroutine object, while giving helpful # errors for common mistakes. Returns coroutine object. ###### -# TODO: Use TypeVarTuple here. def coroutine_or_error( - async_fn: t.Callable[..., t.Awaitable[RetT]], *args: t.Any + async_fn: t.Callable[[Unpack[PosArgsT]], t.Awaitable[RetT]], + *args: Unpack[PosArgsT], ) -> collections.abc.Coroutine[object, t.NoReturn, RetT]: def _return_value_looks_like_wrong_library(value: object) -> bool: # Returned by legacy @asyncio.coroutine functions, which includes