Skip to content

Commit

Permalink
Use TypeVarTuple in our APIs (#2881)
Browse files Browse the repository at this point in the history
* Use TypeVarTuple in various functions, except for Nursery.start().
  That isn't handled by type checkers well yet.
* Fix docs failure
* Make gen_exports create an __all__ list

Co-authored-by: CoolCat467 <52022020+CoolCat467@users.noreply.github.com>
Co-authored-by: EXPLOSION <git@helvetica.moe>
  • Loading branch information
3 people committed Dec 13, 2023
1 parent 31b87ad commit 597e345
Show file tree
Hide file tree
Showing 24 changed files with 261 additions and 74 deletions.
1 change: 1 addition & 0 deletions check.sh
Expand Up @@ -110,6 +110,7 @@ if [ $PYRIGHT -ne 0 ]; then
fi

pyright src/trio/_tests/type_tests || EXIT_STATUS=$?
pyright src/trio/_core/_tests/type_tests || EXIT_STATUS=$?
echo "::endgroup::"

# Finally, leave a really clear warning of any issues and exit
Expand Down
3 changes: 2 additions & 1 deletion docs/source/conf.py
Expand Up @@ -73,7 +73,6 @@
# aliasing doesn't actually fix the warning for types.FrameType, but displaying
# "types.FrameType" is more helpful than just "frame"
"FrameType": "types.FrameType",
"Context": "OpenSSL.SSL.Context",
# SSLListener.accept's return type is seen as trio._ssl.SSLStream
"SSLStream": "trio.SSLStream",
}
Expand All @@ -91,6 +90,8 @@ def autodoc_process_signature(
# Strip the type from the union, make it look like = ...
signature = signature.replace(" | type[trio._core._local._NoValue]", "")
signature = signature.replace("<class 'trio._core._local._NoValue'>", "...")
if "DTLS" in name:
signature = signature.replace("SSL.Context", "OpenSSL.SSL.Context")
# Don't specify PathLike[str] | PathLike[bytes], this is just for humans.
signature = signature.replace("StrOrBytesPath", "str | bytes | os.PathLike")

Expand Down
1 change: 1 addition & 0 deletions newsfragments/2881.feature.rst
@@ -0,0 +1 @@
`TypeVarTuple <https://docs.python.org/3.12/library/typing.html#typing.TypeVarTuple>`_ is now used to fully type :meth:`nursery.start_soon() <trio.Nursery.start_soon>`, :func:`trio.run()`, :func:`trio.to_thread.run_sync()`, and other similar functions accepting ``(func, *args)``. This means type checkers will be able to verify types are used correctly. :meth:`nursery.start() <trio.Nursery.start>` is not fully typed yet however.
4 changes: 4 additions & 0 deletions pyproject.toml
Expand Up @@ -157,6 +157,8 @@ check_untyped_defs = true

[tool.pyright]
pythonVersion = "3.8"
reportUnnecessaryTypeIgnoreComment = true
typeCheckingMode = "strict"

[tool.pytest.ini_options]
addopts = ["--strict-markers", "--strict-config", "-p trio._tests.pytest_plugin"]
Expand Down Expand Up @@ -235,6 +237,8 @@ omit = [
"*/trio/_core/_tests/test_multierror_scripts/*",
# Omit the generated files in trio/_core starting with _generated_
"*/trio/_core/_generated_*",
# Type tests aren't intended to be run, just passed to type checkers.
"*/type_tests/*",
]
# The test suite spawns subprocesses to test some stuff, so make sure
# this doesn't corrupt the coverage files
Expand Down
21 changes: 15 additions & 6 deletions src/trio/_core/_entry_queue.py
Expand Up @@ -2,18 +2,21 @@

import threading
from collections import deque
from typing import Callable, Iterable, NoReturn, Tuple
from typing import TYPE_CHECKING, Callable, NoReturn, Tuple

import attr

from .. import _core
from .._util import NoPublicConstructor, final
from ._wakeup_socketpair import WakeupSocketpair

# TODO: Type with TypeVarTuple, at least to an extent where it makes
# the public interface safe.
if TYPE_CHECKING:
from typing_extensions import TypeVarTuple, Unpack

PosArgsT = TypeVarTuple("PosArgsT")

Function = Callable[..., object]
Job = Tuple[Function, Iterable[object]]
Job = Tuple[Function, Tuple[object, ...]]


@attr.s(slots=True)
Expand Down Expand Up @@ -122,7 +125,10 @@ def size(self) -> int:
return len(self.queue) + len(self.idempotent_queue)

def run_sync_soon(
self, sync_fn: Function, *args: object, idempotent: bool = False
self,
sync_fn: Callable[[Unpack[PosArgsT]], object],
*args: Unpack[PosArgsT],
idempotent: bool = False,
) -> None:
with self.lock:
if self.done:
Expand Down Expand Up @@ -163,7 +169,10 @@ class TrioToken(metaclass=NoPublicConstructor):
_reentry_queue: EntryQueue = attr.ib()

def run_sync_soon(
self, sync_fn: Function, *args: object, idempotent: bool = False
self,
sync_fn: Callable[[Unpack[PosArgsT]], object],
*args: Unpack[PosArgsT],
idempotent: bool = False,
) -> None:
"""Schedule a call to ``sync_fn(*args)`` to occur in the context of a
Trio task.
Expand Down
2 changes: 2 additions & 0 deletions src/trio/_core/_generated_instrumentation.py
Expand Up @@ -11,6 +11,8 @@
if TYPE_CHECKING:
from ._instrumentation import Instrument

__all__ = ["add_instrument", "remove_instrument"]


def add_instrument(instrument: Instrument) -> None:
"""Start instrumenting the current run loop with the given instrument.
Expand Down
3 changes: 3 additions & 0 deletions src/trio/_core/_generated_io_epoll.py
Expand Up @@ -15,6 +15,9 @@
assert not TYPE_CHECKING or sys.platform == "linux"


__all__ = ["notify_closing", "wait_readable", "wait_writable"]


async def wait_readable(fd: (int | _HasFileNo)) -> None:
"""Block until the kernel reports that the given object is readable.
Expand Down
10 changes: 10 additions & 0 deletions src/trio/_core/_generated_io_kqueue.py
Expand Up @@ -19,6 +19,16 @@
assert not TYPE_CHECKING or sys.platform == "darwin"


__all__ = [
"current_kqueue",
"monitor_kevent",
"notify_closing",
"wait_kevent",
"wait_readable",
"wait_writable",
]


def current_kqueue() -> select.kqueue:
"""TODO: these are implemented, but are currently more of a sketch than
anything real. See `#26
Expand Down
13 changes: 13 additions & 0 deletions src/trio/_core/_generated_io_windows.py
Expand Up @@ -19,6 +19,19 @@
assert not TYPE_CHECKING or sys.platform == "win32"


__all__ = [
"current_iocp",
"monitor_completion_key",
"notify_closing",
"readinto_overlapped",
"register_with_iocp",
"wait_overlapped",
"wait_readable",
"wait_writable",
"write_overlapped",
]


async def wait_readable(sock: (_HasFileNo | int)) -> None:
"""Block until the kernel reports that the given object is readable.
Expand Down
18 changes: 16 additions & 2 deletions src/trio/_core/_generated_run.py
Expand Up @@ -13,9 +13,23 @@
from collections.abc import Awaitable, Callable

from outcome import Outcome
from typing_extensions import Unpack

from .._abc import Clock
from ._entry_queue import TrioToken
from ._run import PosArgT


__all__ = [
"current_clock",
"current_root_task",
"current_statistics",
"current_time",
"current_trio_token",
"reschedule",
"spawn_system_task",
"wait_all_tasks_blocked",
]


def current_statistics() -> RunStatistics:
Expand Down Expand Up @@ -113,8 +127,8 @@ def reschedule(task: Task, next_send: Outcome[Any] = _NO_SEND) -> None:


def spawn_system_task(
async_fn: Callable[..., Awaitable[object]],
*args: object,
async_fn: Callable[[Unpack[PosArgT]], Awaitable[object]],
*args: Unpack[PosArgT],
name: object = None,
context: (contextvars.Context | None) = None,
) -> Task:
Expand Down
63 changes: 36 additions & 27 deletions src/trio/_core/_run.py
Expand Up @@ -53,6 +53,12 @@
if sys.version_info < (3, 11):
from exceptiongroup import BaseExceptionGroup

FnT = TypeVar("FnT", bound="Callable[..., Any]")
StatusT = TypeVar("StatusT")
StatusT_co = TypeVar("StatusT_co", covariant=True)
StatusT_contra = TypeVar("StatusT_contra", contravariant=True)
RetT = TypeVar("RetT")


if TYPE_CHECKING:
import contextvars
Expand All @@ -70,19 +76,25 @@
# for some strange reason Sphinx works with outcome.Outcome, but not Outcome, in
# start_guest_run. Same with types.FrameType in iter_await_frames
import outcome
from typing_extensions import Self
from typing_extensions import Self, TypeVarTuple, Unpack

PosArgT = TypeVarTuple("PosArgT")

# Needs to be guarded, since Unpack[] would be evaluated at runtime.
class _NurseryStartFunc(Protocol[Unpack[PosArgT], StatusT_co]):
"""Type of functions passed to `nursery.start() <trio.Nursery.start>`."""

def __call__(
self, *args: Unpack[PosArgT], task_status: TaskStatus[StatusT_co]
) -> Awaitable[object]:
...


DEADLINE_HEAP_MIN_PRUNE_THRESHOLD: Final = 1000

# Passed as a sentinel
_NO_SEND: Final[Outcome[Any]] = cast("Outcome[Any]", object())

FnT = TypeVar("FnT", bound="Callable[..., Any]")
StatusT = TypeVar("StatusT")
StatusT_co = TypeVar("StatusT_co", covariant=True)
StatusT_contra = TypeVar("StatusT_contra", contravariant=True)
RetT = TypeVar("RetT")


@final
class _NoStatus(metaclass=NoPublicConstructor):
Expand Down Expand Up @@ -1119,9 +1131,8 @@ def aborted(raise_cancel: _core.RaiseCancelT) -> Abort:

def start_soon(
self,
# TODO: TypeVarTuple
async_fn: Callable[..., Awaitable[object]],
*args: object,
async_fn: Callable[[Unpack[PosArgT]], Awaitable[object]],
*args: Unpack[PosArgT],
name: object = None,
) -> None:
"""Creates a child task, scheduling ``await async_fn(*args)``.
Expand Down Expand Up @@ -1170,7 +1181,7 @@ async def start(
async_fn: Callable[..., Awaitable[object]],
*args: object,
name: object = None,
) -> StatusT:
) -> Any:
r"""Creates and initializes a child task.
Like :meth:`start_soon`, but blocks until the new task has
Expand Down Expand Up @@ -1219,7 +1230,7 @@ async def async_fn(arg1, arg2, *, task_status=trio.TASK_STATUS_IGNORED):
# `run` option, which would cause it to wrap a pre-started()
# exception in an extra ExceptionGroup. See #2611.
async with open_nursery(strict_exception_groups=False) as old_nursery:
task_status: _TaskStatus[StatusT] = _TaskStatus(old_nursery, self)
task_status: _TaskStatus[Any] = _TaskStatus(old_nursery, self)
thunk = functools.partial(async_fn, task_status=task_status)
task = GLOBAL_RUN_CONTEXT.runner.spawn_impl(
thunk, args, old_nursery, name
Expand All @@ -1232,7 +1243,7 @@ async def async_fn(arg1, arg2, *, task_status=trio.TASK_STATUS_IGNORED):
# (Any exceptions propagate directly out of the above.)
if task_status._value is _NoStatus:
raise RuntimeError("child exited without calling task_status.started()")
return task_status._value # type: ignore[return-value] # Mypy doesn't narrow yet.
return task_status._value
finally:
self._pending_starts -= 1
self._check_nursery_closed()
Expand Down Expand Up @@ -1690,9 +1701,8 @@ def reschedule( # type: ignore[misc]

def spawn_impl(
self,
# TODO: TypeVarTuple
async_fn: Callable[..., Awaitable[object]],
args: tuple[object, ...],
async_fn: Callable[[Unpack[PosArgT]], Awaitable[object]],
args: tuple[Unpack[PosArgT]],
nursery: Nursery | None,
name: object,
*,
Expand Down Expand Up @@ -1721,7 +1731,8 @@ def spawn_impl(
# Call the function and get the coroutine object, while giving helpful
# errors for common mistakes.
######
coro = context.run(coroutine_or_error, async_fn, *args)
# TypeVarTuple passed into ParamSpec function confuses Mypy.
coro = context.run(coroutine_or_error, async_fn, *args) # type: ignore[arg-type]

if name is None:
name = async_fn
Expand Down Expand Up @@ -1808,12 +1819,11 @@ def task_exited(self, task: Task, outcome: Outcome[Any]) -> None:
# System tasks and init
################

@_public # Type-ignore due to use of Any here.
def spawn_system_task( # type: ignore[misc]
@_public
def spawn_system_task(
self,
# TODO: TypeVarTuple
async_fn: Callable[..., Awaitable[object]],
*args: object,
async_fn: Callable[[Unpack[PosArgT]], Awaitable[object]],
*args: Unpack[PosArgT],
name: object = None,
context: contextvars.Context | None = None,
) -> Task:
Expand Down Expand Up @@ -1878,10 +1888,9 @@ def spawn_system_task( # type: ignore[misc]
)

async def init(
# TODO: TypeVarTuple
self,
async_fn: Callable[..., Awaitable[object]],
args: tuple[object, ...],
async_fn: Callable[[Unpack[PosArgT]], Awaitable[object]],
args: tuple[Unpack[PosArgT]],
) -> None:
# run_sync_soon task runs here:
async with open_nursery() as run_sync_soon_nursery:
Expand Down Expand Up @@ -2407,8 +2416,8 @@ def my_done_callback(run_outcome):
# straight through.
def unrolled_run(
runner: Runner,
async_fn: Callable[..., object],
args: tuple[object, ...],
async_fn: Callable[[Unpack[PosArgT]], Awaitable[object]],
args: tuple[Unpack[PosArgT]],
host_uses_signal_set_wakeup_fd: bool = False,
) -> Generator[float, EventResult, None]:
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
Expand Down
2 changes: 0 additions & 2 deletions src/trio/_core/_tests/test_guest_mode.py
Expand Up @@ -658,8 +658,6 @@ async def trio_main() -> None:
# Ensure we don't pollute the thread-level context if run under
# an asyncio without contextvars support (3.6)
context = contextvars.copy_context()
if TYPE_CHECKING:
aiotrio_run(trio_main, host_uses_signal_set_wakeup_fd=True)
context.run(aiotrio_run, trio_main, host_uses_signal_set_wakeup_fd=True)

assert record == {("asyncio", "asyncio"), ("trio", "trio")}
4 changes: 2 additions & 2 deletions src/trio/_core/_tests/test_local.py
Expand Up @@ -77,8 +77,8 @@ async def task1() -> None:
t1.set("plaice")
assert t1.get() == "plaice"

async def task2(tok: str) -> None:
t1.reset(token)
async def task2(tok: RunVarToken[str]) -> None:
t1.reset(tok)

with pytest.raises(LookupError):
t1.get()
Expand Down

0 comments on commit 597e345

Please sign in to comment.