Skip to content

Commit

Permalink
Make ExceptionInfo generic in the exception type
Browse files Browse the repository at this point in the history
This way, in

    with pytest.raises(ValueError) as cm:
        ...

cm.value is a ValueError and not a BaseException.
  • Loading branch information
bluetech committed Jul 10, 2019
1 parent adbbbfc commit 9881f44
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 21 deletions.
21 changes: 13 additions & 8 deletions src/_pytest/_code/code.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@
from inspect import CO_VARKEYWORDS
from traceback import format_exception_only
from types import TracebackType
from typing import Generic
from typing import Optional
from typing import Pattern
from typing import Tuple
from typing import TypeVar
from typing import Union
from weakref import ref

Expand Down Expand Up @@ -379,22 +381,25 @@ def recursionindex(self):
)


_E = TypeVar("_E", bound=BaseException)


@attr.s(repr=False)
class ExceptionInfo:
class ExceptionInfo(Generic[_E]):
""" wraps sys.exc_info() objects and offers
help for navigating the traceback.
"""

_assert_start_repr = "AssertionError('assert "

_excinfo = attr.ib(
type=Optional[Tuple["Type[BaseException]", BaseException, TracebackType]]
)
_excinfo = attr.ib(type=Optional[Tuple["Type[_E]", "_E", TracebackType]])
_striptext = attr.ib(type=str, default="")
_traceback = attr.ib(type=Optional[Traceback], default=None)

@classmethod
def from_current(cls, exprinfo: Optional[str] = None) -> "ExceptionInfo":
def from_current(
cls, exprinfo: Optional[str] = None
) -> "ExceptionInfo[BaseException]":
"""returns an ExceptionInfo matching the current traceback
.. warning::
Expand Down Expand Up @@ -422,21 +427,21 @@ def from_current(cls, exprinfo: Optional[str] = None) -> "ExceptionInfo":
return cls(tup, _striptext)

@classmethod
def for_later(cls) -> "ExceptionInfo":
def for_later(cls) -> "ExceptionInfo[_E]":
"""return an unfilled ExceptionInfo
"""
return cls(None)

@property
def type(self) -> "Type[BaseException]":
def type(self) -> "Type[_E]":
"""the exception class"""
assert (
self._excinfo is not None
), ".type can only be used after the context manager exits"
return self._excinfo[0]

@property
def value(self) -> BaseException:
def value(self) -> _E:
"""the exception value"""
assert (
self._excinfo is not None
Expand Down
33 changes: 20 additions & 13 deletions src/_pytest/python_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,13 @@
from types import TracebackType
from typing import Any
from typing import Callable
from typing import cast
from typing import Generic
from typing import Optional
from typing import overload
from typing import Pattern
from typing import Tuple
from typing import TypeVar
from typing import Union

from more_itertools.more import always_iterable
Expand Down Expand Up @@ -537,33 +540,35 @@ def _is_numpy_array(obj):

# builtin pytest.raises helper

_E = TypeVar("_E", bound=BaseException)


@overload
def raises(
expected_exception: Union["Type[BaseException]", Tuple["Type[BaseException]", ...]],
expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]],
*,
match: Optional[Union[str, Pattern]] = ... # noqa: W504 (SyntaxError in Python 3.5)
) -> "RaisesContext":
) -> "RaisesContext[_E]":
...


@overload
def raises(
expected_exception: Union["Type[BaseException]", Tuple["Type[BaseException]", ...]],
expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]],
func: Callable,
*args: Any,
match: Optional[str] = ...,
**kwargs: Any
) -> Optional[_pytest._code.ExceptionInfo]:
) -> Optional[_pytest._code.ExceptionInfo[_E]]:
...


def raises(
expected_exception: Union["Type[BaseException]", Tuple["Type[BaseException]", ...]],
expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]],
*args: Any,
match: Optional[Union[str, Pattern]] = None,
**kwargs: Any
) -> Union["RaisesContext", Optional[_pytest._code.ExceptionInfo]]:
) -> Union["RaisesContext[_E]", Optional[_pytest._code.ExceptionInfo[_E]]]:
r"""
Assert that a code block/function call raises ``expected_exception``
or raise a failure exception otherwise.
Expand Down Expand Up @@ -702,28 +707,30 @@ def raises(
try:
func(*args[1:], **kwargs)
except expected_exception:
return _pytest._code.ExceptionInfo.from_current()
# Cast to narrow the type to expected_exception (_E).
return cast(
_pytest._code.ExceptionInfo[_E],
_pytest._code.ExceptionInfo.from_current(),
)
fail(message)


raises.Exception = fail.Exception # type: ignore


class RaisesContext:
class RaisesContext(Generic[_E]):
def __init__(
self,
expected_exception: Union[
"Type[BaseException]", Tuple["Type[BaseException]", ...]
],
expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]],
message: str,
match_expr: Optional[Union[str, Pattern]] = None,
) -> None:
self.expected_exception = expected_exception
self.message = message
self.match_expr = match_expr
self.excinfo = None # type: Optional[_pytest._code.ExceptionInfo]
self.excinfo = None # type: Optional[_pytest._code.ExceptionInfo[_E]]

def __enter__(self) -> _pytest._code.ExceptionInfo:
def __enter__(self) -> _pytest._code.ExceptionInfo[_E]:
self.excinfo = _pytest._code.ExceptionInfo.for_later()
return self.excinfo

Expand Down

0 comments on commit 9881f44

Please sign in to comment.