Skip to content

Commit

Permalink
Use Typevar defaults for TaskStatus and Matcher (#3019)
Browse files Browse the repository at this point in the history
* Default TaskStatus to use None if unspecified

* Default Matcher to BaseException if unspecified

* Update Sphinx logic for new typevar name

* Add some type tests for defaulted typevar classes
  • Loading branch information
TeamSpen210 committed Jun 19, 2024
1 parent b93d8a6 commit 26cc6ee
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 34 deletions.
14 changes: 9 additions & 5 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,10 @@ def autodoc_process_signature(
# name.
assert isinstance(obj, property), obj
assert isinstance(obj.fget, types.FunctionType), obj.fget
assert obj.fget.__annotations__["return"] == "type[E]", obj.fget.__annotations__
obj.fget.__annotations__["return"] = "type[~trio.testing._raises_group.E]"
assert (
obj.fget.__annotations__["return"] == "type[MatchE]"
), obj.fget.__annotations__
obj.fget.__annotations__["return"] = "type[~trio.testing._raises_group.MatchE]"
if signature is not None:
signature = signature.replace("~_contextvars.Context", "~contextvars.Context")
if name == "trio.lowlevel.RunVar": # Typevar is not useful here.
Expand All @@ -123,13 +125,15 @@ def autodoc_process_signature(
# Strip the type from the union, make it look like = ...
signature = signature.replace(" | type[trio._core._local._NoValue]", "")
signature = signature.replace("<class 'trio._core._local._NoValue'>", "...")
if (
name in ("trio.testing.RaisesGroup", "trio.testing.Matcher")
and "+E" in signature
if name in ("trio.testing.RaisesGroup", "trio.testing.Matcher") and (
"+E" in signature or "+MatchE" in signature
):
# This typevar being covariant isn't handled correctly in some cases, strip the +
# and insert the fully-qualified name.
signature = signature.replace("+E", "~trio.testing._raises_group.E")
signature = signature.replace(
"+MatchE", "~trio.testing._raises_group.MatchE"
)
if "DTLS" in name:
signature = signature.replace("SSL.Context", "OpenSSL.SSL.Context")
# Don't specify PathLike[str] | PathLike[bytes], this is just for humans.
Expand Down
19 changes: 11 additions & 8 deletions src/trio/_core/_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
Final,
NoReturn,
Protocol,
TypeVar,
cast,
overload,
)
Expand Down Expand Up @@ -54,12 +53,6 @@
if sys.version_info < (3, 11):
from exceptiongroup import BaseExceptionGroup

FnT = TypeVar("FnT", bound="Callable[..., Any]")
StatusT = TypeVar("StatusT")
StatusT_co = TypeVar("StatusT_co", covariant=True)
StatusT_contra = TypeVar("StatusT_contra", contravariant=True)
RetT = TypeVar("RetT")


if TYPE_CHECKING:
import contextvars
Expand All @@ -77,9 +70,19 @@
# for some strange reason Sphinx works with outcome.Outcome, but not Outcome, in
# start_guest_run. Same with types.FrameType in iter_await_frames
import outcome
from typing_extensions import Self, TypeVarTuple, Unpack
from typing_extensions import Self, TypeVar, TypeVarTuple, Unpack

PosArgT = TypeVarTuple("PosArgT")
StatusT = TypeVar("StatusT", default=None)
StatusT_contra = TypeVar("StatusT_contra", contravariant=True, default=None)
else:
from typing import TypeVar

StatusT = TypeVar("StatusT")
StatusT_contra = TypeVar("StatusT_contra", contravariant=True)

FnT = TypeVar("FnT", bound="Callable[..., Any]")
RetT = TypeVar("RetT")


DEADLINE_HEAP_MIN_PRUNE_THRESHOLD: Final = 1000
Expand Down
8 changes: 8 additions & 0 deletions src/trio/_tests/type_tests/raisesgroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,14 @@ def check_inheritance_and_assignments() -> None:
assert a


def check_matcher_typevar_default(e: Matcher) -> object:
assert e.exception_type is not None
exc: type[BaseException] = e.exception_type
# this would previously pass, as the type would be `Any`
e.exception_type().blah() # type: ignore
return exc # Silence Pyright unused var warning


def check_basic_contextmanager() -> None:
# One level of Group is correctly translated - except it's a BaseExceptionGroup
# instead of an ExceptionGroup.
Expand Down
29 changes: 29 additions & 0 deletions src/trio/_tests/type_tests/task_status.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
"""Check that started() can only be called for TaskStatus[None]."""

from trio import TaskStatus
from typing_extensions import assert_type


async def check_status(
none_status_explicit: TaskStatus[None],
none_status_implicit: TaskStatus,
int_status: TaskStatus[int],
) -> None:
assert_type(none_status_explicit, TaskStatus[None])
assert_type(none_status_implicit, TaskStatus[None]) # Default typevar
assert_type(int_status, TaskStatus[int])

# Omitting the parameter is only allowed for None.
none_status_explicit.started()
none_status_implicit.started()
int_status.started() # type: ignore

# Explicit None is allowed.
none_status_explicit.started(None)
none_status_implicit.started(None)
int_status.started(None) # type: ignore

none_status_explicit.started(42) # type: ignore
none_status_implicit.started(42) # type: ignore
int_status.started(42)
int_status.started(True)
55 changes: 34 additions & 21 deletions src/trio/testing/_raises_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
Literal,
Pattern,
Sequence,
TypeVar,
cast,
overload,
)
Expand All @@ -26,43 +25,57 @@
import types

from _pytest._code.code import ExceptionChainRepr, ReprExceptionInfo, Traceback
from typing_extensions import TypeGuard
from typing_extensions import TypeGuard, TypeVar

if sys.version_info < (3, 11):
from exceptiongroup import BaseExceptionGroup
MatchE = TypeVar(
"MatchE", bound=BaseException, default=BaseException, covariant=True
)
else:
from typing import TypeVar

MatchE = TypeVar("MatchE", bound=BaseException, covariant=True)
# RaisesGroup doesn't work with a default.
E = TypeVar("E", bound=BaseException, covariant=True)
# These two typevars are special cased in sphinx config to workaround lookup bugs.

if sys.version_info < (3, 11):
from exceptiongroup import BaseExceptionGroup


@final
class _ExceptionInfo(Generic[E]):
class _ExceptionInfo(Generic[MatchE]):
"""Minimal re-implementation of pytest.ExceptionInfo, only used if pytest is not available. Supports a subset of its features necessary for functionality of :class:`trio.testing.RaisesGroup` and :class:`trio.testing.Matcher`."""

_excinfo: tuple[type[E], E, types.TracebackType] | None
_excinfo: tuple[type[MatchE], MatchE, types.TracebackType] | None

def __init__(self, excinfo: tuple[type[E], E, types.TracebackType] | None):
def __init__(
self, excinfo: tuple[type[MatchE], MatchE, types.TracebackType] | None
):
self._excinfo = excinfo

def fill_unfilled(self, exc_info: tuple[type[E], E, types.TracebackType]) -> None:
def fill_unfilled(
self, exc_info: tuple[type[MatchE], MatchE, types.TracebackType]
) -> None:
"""Fill an unfilled ExceptionInfo created with ``for_later()``."""
assert self._excinfo is None, "ExceptionInfo was already filled"
self._excinfo = exc_info

@classmethod
def for_later(cls) -> _ExceptionInfo[E]:
def for_later(cls) -> _ExceptionInfo[MatchE]:
"""Return an unfilled ExceptionInfo."""
return cls(None)

# Note, special cased in sphinx config, since "type" conflicts.
@property
def type(self) -> type[E]:
def type(self) -> type[MatchE]:
"""The exception class."""
assert (
self._excinfo is not None
), ".type can only be used after the context manager exits"
return self._excinfo[0]

@property
def value(self) -> E:
def value(self) -> MatchE:
"""The exception value."""
assert (
self._excinfo is not None
Expand Down Expand Up @@ -95,7 +108,7 @@ def getrepr(
showlocals: bool = False,
style: str = "long",
abspath: bool = False,
tbfilter: bool | Callable[[_ExceptionInfo[BaseException]], Traceback] = True,
tbfilter: bool | Callable[[_ExceptionInfo], Traceback] = True,
funcargs: bool = False,
truncate_locals: bool = True,
chain: bool = True,
Expand Down Expand Up @@ -135,7 +148,7 @@ def _stringify_exception(exc: BaseException) -> str:


@final
class Matcher(Generic[E]):
class Matcher(Generic[MatchE]):
"""Helper class to be used together with RaisesGroups when you want to specify requirements on sub-exceptions. Only specifying the type is redundant, and it's also unnecessary when the type is a nested `RaisesGroup` since it supports the same arguments.
The type is checked with `isinstance`, and does not need to be an exact match. If that is wanted you can use the ``check`` parameter.
:meth:`trio.testing.Matcher.matches` can also be used standalone to check individual exceptions.
Expand All @@ -154,10 +167,10 @@ class Matcher(Generic[E]):
# At least one of the three parameters must be passed.
@overload
def __init__(
self: Matcher[E],
exception_type: type[E],
self: Matcher[MatchE],
exception_type: type[MatchE],
match: str | Pattern[str] = ...,
check: Callable[[E], bool] = ...,
check: Callable[[MatchE], bool] = ...,
): ...

@overload
Expand All @@ -174,9 +187,9 @@ def __init__(self, *, check: Callable[[BaseException], bool]): ...

def __init__(
self,
exception_type: type[E] | None = None,
exception_type: type[MatchE] | None = None,
match: str | Pattern[str] | None = None,
check: Callable[[E], bool] | None = None,
check: Callable[[MatchE], bool] | None = None,
):
if exception_type is None and match is None and check is None:
raise ValueError("You must specify at least one parameter to match on.")
Expand All @@ -192,7 +205,7 @@ def __init__(
self.match = match
self.check = check

def matches(self, exception: BaseException) -> TypeGuard[E]:
def matches(self, exception: BaseException) -> TypeGuard[MatchE]:
"""Check if an exception matches the requirements of this Matcher.
Examples::
Expand Down Expand Up @@ -220,7 +233,7 @@ def matches(self, exception: BaseException) -> TypeGuard[E]:
return False
# If exception_type is None check() accepts BaseException.
# If non-none, we have done an isinstance check above.
if self.check is not None and not self.check(cast(E, exception)):
if self.check is not None and not self.check(cast(MatchE, exception)):
return False
return True

Expand Down Expand Up @@ -254,8 +267,8 @@ def __str__(self) -> str:
# We lie to type checkers that we inherit, so excinfo.value and sub-exceptiongroups can be treated as ExceptionGroups
if TYPE_CHECKING:
SuperClass = BaseExceptionGroup
# Inheriting at runtime leads to a series of TypeErrors, so we do not want to do that.
else:
# At runtime, use a redundant Generic base class which effectively gets ignored.
SuperClass = Generic


Expand Down

0 comments on commit 26cc6ee

Please sign in to comment.