Skip to content

Commit

Permalink
Type-annotate pytest.warns
Browse files Browse the repository at this point in the history
  • Loading branch information
bluetech committed Jul 14, 2019
1 parent d7ee3da commit 2dca68b
Showing 1 changed file with 87 additions and 23 deletions.
110 changes: 87 additions & 23 deletions src/_pytest/recwarn.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,23 @@
""" recording warnings during test function execution. """
import inspect
import re
import warnings
from types import TracebackType
from typing import Any
from typing import Callable
from typing import Iterator
from typing import List
from typing import Optional
from typing import overload
from typing import Pattern
from typing import Tuple
from typing import Union

from _pytest.fixtures import yield_fixture
from _pytest.outcomes import fail

if False: # TYPE_CHECKING
from typing import Type


@yield_fixture
def recwarn():
Expand Down Expand Up @@ -42,7 +54,32 @@ def deprecated_call(func=None, *args, **kwargs):
return warns((DeprecationWarning, PendingDeprecationWarning), *args, **kwargs)


def warns(expected_warning, *args, match=None, **kwargs):
@overload
def warns(
expected_warning: Union["Type[Warning]", Tuple["Type[Warning]", ...]],
*,
match: Optional[Union[str, Pattern]] = ...
) -> "WarningsChecker":
... # pragma: no cover


@overload
def warns(
expected_warning: Union["Type[Warning]", Tuple["Type[Warning]", ...]],
func: Callable,
*args: Any,
match: Optional[Union[str, Pattern]] = ...,
**kwargs: Any
) -> Union[Any]:
... # pragma: no cover


def warns(
expected_warning: Union["Type[Warning]", Tuple["Type[Warning]", ...]],
*args: Any,
match: Optional[Union[str, Pattern]] = None,
**kwargs: Any
) -> Union["WarningsChecker", Any]:
r"""Assert that code raises a particular class of warning.
Specifically, the parameter ``expected_warning`` can be a warning class or
Expand Down Expand Up @@ -101,81 +138,107 @@ class WarningsRecorder(warnings.catch_warnings):
def __init__(self):
super().__init__(record=True)
self._entered = False
self._list = []
self._list = [] # type: List[warnings._Record]

This comment has been minimized.

Copy link
@blueyed

blueyed Jan 14, 2020

Contributor

@bluetech
This fails since python/typeshed@c44a556, i.e. likely with the next mypy release:

Name 'warnings._Record' is not defined

This comment has been minimized.

Copy link
@bluetech

bluetech Jan 14, 2020

Author Member

Thanks for letting me know. When we update to next mypy, we'll need to change warnings._Record to warnings.WarningMessage. Seems like it's a real class present at least from py3.5, so can be used without wrapping in "".


@property
def list(self):
def list(self) -> List["warnings._Record"]:
"""The list of recorded warnings."""
return self._list

def __getitem__(self, i):
def __getitem__(self, i: int) -> "warnings._Record":
"""Get a recorded warning by index."""
return self._list[i]

def __iter__(self):
def __iter__(self) -> Iterator["warnings._Record"]:
"""Iterate through the recorded warnings."""
return iter(self._list)

def __len__(self):
def __len__(self) -> int:
"""The number of recorded warnings."""
return len(self._list)

def pop(self, cls=Warning):
def pop(self, cls: "Type[Warning]" = Warning) -> "warnings._Record":
"""Pop the first recorded warning, raise exception if not exists."""
for i, w in enumerate(self._list):
if issubclass(w.category, cls):
return self._list.pop(i)
__tracebackhide__ = True
raise AssertionError("%r not found in warning list" % cls)

def clear(self):
def clear(self) -> None:
"""Clear the list of recorded warnings."""
self._list[:] = []

def __enter__(self):
# Type ignored because it doesn't exactly warnings.catch_warnings.__enter__
# -- it returns a List but we only emulate one.
def __enter__(self) -> "WarningsRecorder": # type: ignore
if self._entered:
__tracebackhide__ = True
raise RuntimeError("Cannot enter %r twice" % self)
self._list = super().__enter__()
_list = super().__enter__()
# record=True means it's None.
assert _list is not None
self._list = _list
warnings.simplefilter("always")
return self

def __exit__(self, *exc_info):
def __exit__(
self,
exc_type: Optional["Type[BaseException]"],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> bool:
if not self._entered:
__tracebackhide__ = True
raise RuntimeError("Cannot exit %r without entering first" % self)

super().__exit__(*exc_info)
super().__exit__(exc_type, exc_val, exc_tb)

# Built-in catch_warnings does not reset entered state so we do it
# manually here for this context manager to become reusable.
self._entered = False

return False


class WarningsChecker(WarningsRecorder):
def __init__(self, expected_warning=None, match_expr=None):
def __init__(
self,
expected_warning: Optional[
Union["Type[Warning]", Tuple["Type[Warning]", ...]]
] = None,
match_expr: Optional[Union[str, Pattern]] = None,
) -> None:
super().__init__()

msg = "exceptions must be derived from Warning, not %s"
if isinstance(expected_warning, tuple):
if expected_warning is None:
expected_warning_tup = None
elif isinstance(expected_warning, tuple):
for exc in expected_warning:
if not inspect.isclass(exc):
if not issubclass(exc, Warning):
raise TypeError(msg % type(exc))
elif inspect.isclass(expected_warning):
expected_warning = (expected_warning,)
elif expected_warning is not None:
expected_warning_tup = expected_warning
elif issubclass(expected_warning, Warning):
expected_warning_tup = (expected_warning,)
else:
raise TypeError(msg % type(expected_warning))

self.expected_warning = expected_warning
self.expected_warning = expected_warning_tup
self.match_expr = match_expr

def __exit__(self, *exc_info):
super().__exit__(*exc_info)
def __exit__(
self,
exc_type: Optional["Type[BaseException]"],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> bool:
super().__exit__(exc_type, exc_val, exc_tb)

__tracebackhide__ = True

# only check if we're not currently handling an exception
if all(a is None for a in exc_info):
if exc_type is None and exc_val is None and exc_tb is None:
if self.expected_warning is not None:
if not any(issubclass(r.category, self.expected_warning) for r in self):
__tracebackhide__ = True
Expand All @@ -200,3 +263,4 @@ def __exit__(self, *exc_info):
[each.message for each in self],
)
)
return False

0 comments on commit 2dca68b

Please sign in to comment.