From 6f625758e4e09c52442ecdb4b9355ae07db4ddad Mon Sep 17 00:00:00 2001 From: John Litborn <11260241+jakkdl@users.noreply.github.com> Date: Fri, 17 May 2024 12:44:41 +0200 Subject: [PATCH] Add support to testing.RaisesGroup for catching unwrapped exceptions (#2989) * Add support to testing.RaisesGroup for catching unwrapped exceptions with strict=False * fix type error by adding covariance to typevar * rewrite RaisesGroup docstring * Work around +E typevar issue in docs for _raises_group * Fix docs issue with type property in _ExceptionInfo * split 'strict' into 'flatten_subgroups' and 'allow_unwrapped', fix bug where length check would fail incorrectly sometimes if using flatten_subgroups * add deprecation of strict * bump exceptiongroup to 1.2.1 * fix ^$ matching on exceptiongroups * add test case for nested exceptiongroup + allow_unwrapped * add signature overloads for RaisesGroup to raise type errors when doing incorrect incantations * add pytest.deprecated_call() test * add type tests for narrowing of check argument --------- Co-authored-by: Spencer Brown Co-authored-by: Zac Hatfield-Dodds --- docs/source/conf.py | 18 ++- newsfragments/2989.bugfix.rst | 2 + newsfragments/2989.deprecated.rst | 1 + newsfragments/2989.feature.rst | 1 + src/trio/_core/_run.py | 3 + src/trio/_core/_unbounded_queue.py | 1 + src/trio/_deprecate.py | 22 ++- src/trio/_highlevel_open_tcp_listeners.py | 1 + src/trio/_tests/test_deprecate.py | 36 +++-- src/trio/_tests/test_testing_raisesgroup.py | 132 +++++++++++++++-- src/trio/_tests/type_tests/raisesgroup.py | 130 ++++++++++++++++- src/trio/_threads.py | 1 + src/trio/testing/_raises_group.py | 149 ++++++++++++++++---- test-requirements.in | 4 +- 14 files changed, 445 insertions(+), 56 deletions(-) create mode 100644 newsfragments/2989.bugfix.rst create mode 100644 newsfragments/2989.deprecated.rst create mode 100644 newsfragments/2989.feature.rst diff --git a/docs/source/conf.py b/docs/source/conf.py index 43ce2aa686..ff08adab48 100755 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -21,6 +21,7 @@ import collections.abc import os import sys +import types from typing import TYPE_CHECKING, cast if TYPE_CHECKING: @@ -98,7 +99,7 @@ # https://www.sphinx-doc.org/en/master/usage/extensions/autodoc.html#event-autodoc-process-signature def autodoc_process_signature( app: Sphinx, - what: object, + what: str, name: str, obj: object, options: object, @@ -106,6 +107,14 @@ def autodoc_process_signature( return_annotation: str, ) -> tuple[str, str]: """Modify found signatures to fix various issues.""" + if name == "trio.testing._raises_group._ExceptionInfo.type": + # This has the type "type[E]", which gets resolved into the property itself. + # That means Sphinx can't resolve it. Fix the issue by overwriting with a fully-qualified + # name. + assert isinstance(obj, property), obj + assert isinstance(obj.fget, types.FunctionType), obj.fget + assert obj.fget.__annotations__["return"] == "type[E]", obj.fget.__annotations__ + obj.fget.__annotations__["return"] = "type[~trio.testing._raises_group.E]" if signature is not None: signature = signature.replace("~_contextvars.Context", "~contextvars.Context") if name == "trio.lowlevel.RunVar": # Typevar is not useful here. @@ -114,6 +123,13 @@ def autodoc_process_signature( # Strip the type from the union, make it look like = ... signature = signature.replace(" | type[trio._core._local._NoValue]", "") signature = signature.replace("", "...") + if ( + name in ("trio.testing.RaisesGroup", "trio.testing.Matcher") + and "+E" in signature + ): + # This typevar being covariant isn't handled correctly in some cases, strip the + + # and insert the fully-qualified name. + signature = signature.replace("+E", "~trio.testing._raises_group.E") if "DTLS" in name: signature = signature.replace("SSL.Context", "OpenSSL.SSL.Context") # Don't specify PathLike[str] | PathLike[bytes], this is just for humans. diff --git a/newsfragments/2989.bugfix.rst b/newsfragments/2989.bugfix.rst new file mode 100644 index 0000000000..55d97d87c6 --- /dev/null +++ b/newsfragments/2989.bugfix.rst @@ -0,0 +1,2 @@ +Fixed a bug where :class:`trio.testing.RaisesGroup(..., strict=False) ` would check the number of exceptions in the raised `ExceptionGroup` before flattening subgroups, leading to incorrectly failed matches. +It now properly supports end (``$``) regex markers in the ``match`` message, by no longer including " (x sub-exceptions)" in the string it matches against. diff --git a/newsfragments/2989.deprecated.rst b/newsfragments/2989.deprecated.rst new file mode 100644 index 0000000000..bc91844d53 --- /dev/null +++ b/newsfragments/2989.deprecated.rst @@ -0,0 +1 @@ +Deprecated ``strict`` parameter from :class:`trio.testing.RaisesGroup`, previous functionality of ``strict=False`` is now in ``flatten_subgroups=True``. diff --git a/newsfragments/2989.feature.rst b/newsfragments/2989.feature.rst new file mode 100644 index 0000000000..c49fdf36f6 --- /dev/null +++ b/newsfragments/2989.feature.rst @@ -0,0 +1 @@ +:class:`trio.testing.RaisesGroup` can now catch an unwrapped exception with ``unwrapped=True``. This means that the behaviour of :ref:`except* ` can be fully replicated in combination with ``flatten_subgroups=True`` (formerly ``strict=False``). diff --git a/src/trio/_core/_run.py b/src/trio/_core/_run.py index a8b632ce53..89759cc2c2 100644 --- a/src/trio/_core/_run.py +++ b/src/trio/_core/_run.py @@ -1012,6 +1012,7 @@ def open_nursery( "the default value of True and rewrite exception handlers to handle ExceptionGroups. " "See https://trio.readthedocs.io/en/stable/reference-core.html#designing-for-multiple-errors" ), + use_triodeprecationwarning=True, ) if strict_exception_groups is None: @@ -2271,6 +2272,7 @@ def run( "the default value of True and rewrite exception handlers to handle ExceptionGroups. " "See https://trio.readthedocs.io/en/stable/reference-core.html#designing-for-multiple-errors" ), + use_triodeprecationwarning=True, ) __tracebackhide__ = True @@ -2387,6 +2389,7 @@ def my_done_callback(run_outcome): "the default value of True and rewrite exception handlers to handle ExceptionGroups. " "See https://trio.readthedocs.io/en/stable/reference-core.html#designing-for-multiple-errors" ), + use_triodeprecationwarning=True, ) runner = setup_runner( diff --git a/src/trio/_core/_unbounded_queue.py b/src/trio/_core/_unbounded_queue.py index 562d921d05..b9ebe484d7 100644 --- a/src/trio/_core/_unbounded_queue.py +++ b/src/trio/_core/_unbounded_queue.py @@ -66,6 +66,7 @@ class UnboundedQueue(Generic[T]): issue=497, thing="trio.lowlevel.UnboundedQueue", instead="trio.open_memory_channel(math.inf)", + use_triodeprecationwarning=True, ) def __init__(self) -> None: self._lot = _core.ParkingLot() diff --git a/src/trio/_deprecate.py b/src/trio/_deprecate.py index 9a19f219c9..51c51f7378 100644 --- a/src/trio/_deprecate.py +++ b/src/trio/_deprecate.py @@ -58,6 +58,7 @@ def warn_deprecated( issue: int | None, instead: object, stacklevel: int = 2, + use_triodeprecationwarning: bool = False, ) -> None: stacklevel += 1 msg = f"{_stringify(thing)} is deprecated since Trio {version}" @@ -67,20 +68,35 @@ def warn_deprecated( msg += f"; use {_stringify(instead)} instead" if issue is not None: msg += f" ({_url_for_issue(issue)})" - warnings.warn(TrioDeprecationWarning(msg), stacklevel=stacklevel) + if use_triodeprecationwarning: + warning_class: type[Warning] = TrioDeprecationWarning + else: + warning_class = DeprecationWarning + warnings.warn(warning_class(msg), stacklevel=stacklevel) # @deprecated("0.2.0", issue=..., instead=...) # def ... def deprecated( - version: str, *, thing: object = None, issue: int | None, instead: object + version: str, + *, + thing: object = None, + issue: int | None, + instead: object, + use_triodeprecationwarning: bool = False, ) -> Callable[[Callable[ArgsT, RetT]], Callable[ArgsT, RetT]]: def do_wrap(fn: Callable[ArgsT, RetT]) -> Callable[ArgsT, RetT]: nonlocal thing @wraps(fn) def wrapper(*args: ArgsT.args, **kwargs: ArgsT.kwargs) -> RetT: - warn_deprecated(thing, version, instead=instead, issue=issue) + warn_deprecated( + thing, + version, + instead=instead, + issue=issue, + use_triodeprecationwarning=use_triodeprecationwarning, + ) return fn(*args, **kwargs) # If our __module__ or __qualname__ get modified, we want to pick up diff --git a/src/trio/_highlevel_open_tcp_listeners.py b/src/trio/_highlevel_open_tcp_listeners.py index 95c1463394..02df5ef704 100644 --- a/src/trio/_highlevel_open_tcp_listeners.py +++ b/src/trio/_highlevel_open_tcp_listeners.py @@ -56,6 +56,7 @@ def _compute_backlog(backlog: int | None) -> int: version="0.23.0", instead="None", issue=2842, + use_triodeprecationwarning=True, ) if not isinstance(backlog, int) and backlog is not None: raise TypeError(f"backlog must be an int or None, not {backlog!r}") diff --git a/src/trio/_tests/test_deprecate.py b/src/trio/_tests/test_deprecate.py index 48130e66f1..fa5d7cbfef 100644 --- a/src/trio/_tests/test_deprecate.py +++ b/src/trio/_tests/test_deprecate.py @@ -38,7 +38,7 @@ def deprecated_thing() -> None: deprecated_thing() filename, lineno = _here() assert len(recwarn_always) == 1 - got = recwarn_always.pop(TrioDeprecationWarning) + got = recwarn_always.pop(DeprecationWarning) assert isinstance(got.message, Warning) assert "ice is deprecated" in got.message.args[0] assert "Trio 1.2" in got.message.args[0] @@ -54,7 +54,7 @@ def test_warn_deprecated_no_instead_or_issue( # Explicitly no instead or issue warn_deprecated("water", "1.3", issue=None, instead=None) assert len(recwarn_always) == 1 - got = recwarn_always.pop(TrioDeprecationWarning) + got = recwarn_always.pop(DeprecationWarning) assert isinstance(got.message, Warning) assert "water is deprecated" in got.message.args[0] assert "no replacement" in got.message.args[0] @@ -70,7 +70,7 @@ def nested2() -> None: filename, lineno = _here() nested1() - got = recwarn_always.pop(TrioDeprecationWarning) + got = recwarn_always.pop(DeprecationWarning) assert got.filename == filename assert got.lineno == lineno + 1 @@ -85,7 +85,7 @@ def new() -> None: # pragma: no cover def test_warn_deprecated_formatting(recwarn_always: pytest.WarningsRecorder) -> None: warn_deprecated(old, "1.0", issue=1, instead=new) - got = recwarn_always.pop(TrioDeprecationWarning) + got = recwarn_always.pop(DeprecationWarning) assert isinstance(got.message, Warning) assert "test_deprecate.old is deprecated" in got.message.args[0] assert "test_deprecate.new instead" in got.message.args[0] @@ -98,7 +98,7 @@ def deprecated_old() -> int: def test_deprecated_decorator(recwarn_always: pytest.WarningsRecorder) -> None: assert deprecated_old() == 3 - got = recwarn_always.pop(TrioDeprecationWarning) + got = recwarn_always.pop(DeprecationWarning) assert isinstance(got.message, Warning) assert "test_deprecate.deprecated_old is deprecated" in got.message.args[0] assert "1.5" in got.message.args[0] @@ -115,7 +115,7 @@ def method(self) -> int: def test_deprecated_decorator_method(recwarn_always: pytest.WarningsRecorder) -> None: f = Foo() assert f.method() == 7 - got = recwarn_always.pop(TrioDeprecationWarning) + got = recwarn_always.pop(DeprecationWarning) assert isinstance(got.message, Warning) assert "test_deprecate.Foo.method is deprecated" in got.message.args[0] @@ -129,7 +129,7 @@ def test_deprecated_decorator_with_explicit_thing( recwarn_always: pytest.WarningsRecorder, ) -> None: assert deprecated_with_thing() == 72 - got = recwarn_always.pop(TrioDeprecationWarning) + got = recwarn_always.pop(DeprecationWarning) assert isinstance(got.message, Warning) assert "the thing is deprecated" in got.message.args[0] @@ -143,7 +143,7 @@ def new_hotness() -> str: def test_deprecated_alias(recwarn_always: pytest.WarningsRecorder) -> None: assert old_hotness() == "new hotness" - got = recwarn_always.pop(TrioDeprecationWarning) + got = recwarn_always.pop(DeprecationWarning) assert isinstance(got.message, Warning) assert "test_deprecate.old_hotness is deprecated" in got.message.args[0] assert "1.23" in got.message.args[0] @@ -168,7 +168,7 @@ def new_hotness_method(self) -> str: def test_deprecated_alias_method(recwarn_always: pytest.WarningsRecorder) -> None: obj = Alias() assert obj.old_hotness_method() == "new hotness method" - got = recwarn_always.pop(TrioDeprecationWarning) + got = recwarn_always.pop(DeprecationWarning) assert isinstance(got.message, Warning) msg = got.message.args[0] assert "test_deprecate.Alias.old_hotness_method is deprecated" in msg @@ -243,7 +243,7 @@ def test_module_with_deprecations(recwarn_always: pytest.WarningsRecorder) -> No filename, lineno = _here() assert module_with_deprecations.dep1 == "value1" # type: ignore[attr-defined] - got = recwarn_always.pop(TrioDeprecationWarning) + got = recwarn_always.pop(DeprecationWarning) assert isinstance(got.message, Warning) assert got.filename == filename assert got.lineno == lineno + 1 @@ -254,9 +254,23 @@ def test_module_with_deprecations(recwarn_always: pytest.WarningsRecorder) -> No assert "value1 instead" in got.message.args[0] assert module_with_deprecations.dep2 == "value2" # type: ignore[attr-defined] - got = recwarn_always.pop(TrioDeprecationWarning) + got = recwarn_always.pop(DeprecationWarning) assert isinstance(got.message, Warning) assert "instead-string instead" in got.message.args[0] with pytest.raises(AttributeError): module_with_deprecations.asdf # type: ignore[attr-defined] # noqa: B018 # "useless expression" + + +def test_warning_class() -> None: + with pytest.deprecated_call(): + warn_deprecated("foo", "bar", issue=None, instead=None) + + # essentially the same as the above check + with pytest.warns(DeprecationWarning): + warn_deprecated("foo", "bar", issue=None, instead=None) + + with pytest.warns(TrioDeprecationWarning): + warn_deprecated( + "foo", "bar", issue=None, instead=None, use_triodeprecationwarning=True + ) diff --git a/src/trio/_tests/test_testing_raisesgroup.py b/src/trio/_tests/test_testing_raisesgroup.py index 9b6b2a6fb6..1e96d38e52 100644 --- a/src/trio/_tests/test_testing_raisesgroup.py +++ b/src/trio/_tests/test_testing_raisesgroup.py @@ -78,39 +78,108 @@ def test_raises_group() -> None: with RaisesGroup(ValueError, SyntaxError): raise ExceptionGroup("", (ValueError(),)) + +def test_flatten_subgroups() -> None: # loose semantics, as with expect* - with RaisesGroup(ValueError, strict=False): + with RaisesGroup(ValueError, flatten_subgroups=True): raise ExceptionGroup("", (ExceptionGroup("", (ValueError(),)),)) + with RaisesGroup(ValueError, TypeError, flatten_subgroups=True): + raise ExceptionGroup("", (ExceptionGroup("", (ValueError(), TypeError())),)) + with RaisesGroup(ValueError, TypeError, flatten_subgroups=True): + raise ExceptionGroup("", [ExceptionGroup("", [ValueError()]), TypeError()]) + # mixed loose is possible if you want it to be at least N deep - with RaisesGroup(RaisesGroup(ValueError, strict=False)): + with RaisesGroup(RaisesGroup(ValueError, flatten_subgroups=True)): raise ExceptionGroup("", (ExceptionGroup("", (ValueError(),)),)) - with RaisesGroup(RaisesGroup(ValueError, strict=False)): + with RaisesGroup(RaisesGroup(ValueError, flatten_subgroups=True)): raise ExceptionGroup( "", (ExceptionGroup("", (ExceptionGroup("", (ValueError(),)),)),) ) with pytest.raises(ExceptionGroup): - with RaisesGroup(RaisesGroup(ValueError, strict=False)): + with RaisesGroup(RaisesGroup(ValueError, flatten_subgroups=True)): raise ExceptionGroup("", (ValueError(),)) # but not the other way around with pytest.raises( ValueError, - match="^You cannot specify a nested structure inside a RaisesGroup with strict=False$", + match="^You cannot specify a nested structure inside a RaisesGroup with", ): - RaisesGroup(RaisesGroup(ValueError), strict=False) + RaisesGroup(RaisesGroup(ValueError), flatten_subgroups=True) # type: ignore[call-overload] + + +def test_catch_unwrapped_exceptions() -> None: + # Catches lone exceptions with strict=False + # just as except* would + with RaisesGroup(ValueError, allow_unwrapped=True): + raise ValueError - # currently not fully identical in behaviour to expect*, which would also catch an unwrapped exception + # expecting multiple unwrapped exceptions is not possible + with pytest.raises( + ValueError, match="^You cannot specify multiple exceptions with" + ): + RaisesGroup(SyntaxError, ValueError, allow_unwrapped=True) # type: ignore[call-overload] + # if users want one of several exception types they need to use a Matcher + # (which the error message suggests) + with RaisesGroup( + Matcher(check=lambda e: isinstance(e, (SyntaxError, ValueError))), + allow_unwrapped=True, + ): + raise ValueError + + # Unwrapped nested `RaisesGroup` is likely a user error, so we raise an error. + with pytest.raises(ValueError, match="has no effect when expecting"): + RaisesGroup(RaisesGroup(ValueError), allow_unwrapped=True) # type: ignore[call-overload] + + # But it *can* be used to check for nesting level +- 1 if they move it to + # the nested RaisesGroup. Users should probably use `Matcher`s instead though. + with RaisesGroup(RaisesGroup(ValueError, allow_unwrapped=True)): + raise ExceptionGroup("", [ExceptionGroup("", [ValueError()])]) + with RaisesGroup(RaisesGroup(ValueError, allow_unwrapped=True)): + raise ExceptionGroup("", [ValueError()]) + + # with allow_unwrapped=False (default) it will not be caught with pytest.raises(ValueError, match="^value error text$"): - with RaisesGroup(ValueError, strict=False): + with RaisesGroup(ValueError): raise ValueError("value error text") + # allow_unwrapped on it's own won't match against nested groups + with pytest.raises(ExceptionGroup): + with RaisesGroup(ValueError, allow_unwrapped=True): + raise ExceptionGroup("", [ExceptionGroup("", [ValueError()])]) + + # for that you need both allow_unwrapped and flatten_subgroups + with RaisesGroup(ValueError, allow_unwrapped=True, flatten_subgroups=True): + raise ExceptionGroup("", [ExceptionGroup("", [ValueError()])]) + + # code coverage + with pytest.raises(TypeError): + with RaisesGroup(ValueError, allow_unwrapped=True): + raise TypeError + def test_match() -> None: # supports match string with RaisesGroup(ValueError, match="bar"): raise ExceptionGroup("bar", (ValueError(),)) + # now also works with ^$ + with RaisesGroup(ValueError, match="^bar$"): + raise ExceptionGroup("bar", (ValueError(),)) + + # it also includes notes + with RaisesGroup(ValueError, match="my note"): + e = ExceptionGroup("bar", (ValueError(),)) + e.add_note("my note") + raise e + + # and technically you can match it all with ^$ + # but you're probably better off using a Matcher at that point + with RaisesGroup(ValueError, match="^bar\nmy note$"): + e = ExceptionGroup("bar", (ValueError(),)) + e.add_note("my note") + raise e + with pytest.raises(ExceptionGroup): with RaisesGroup(ValueError, match="foo"): raise ExceptionGroup("bar", (ValueError(),)) @@ -125,6 +194,37 @@ def test_check() -> None: raise ExceptionGroup("", (ValueError(),)) +def test_unwrapped_match_check() -> None: + def my_check(e: object) -> bool: # pragma: no cover + return True + + msg = ( + "`allow_unwrapped=True` bypasses the `match` and `check` parameters" + " if the exception is unwrapped. If you intended to match/check the" + " exception you should use a `Matcher` object. If you want to match/check" + " the exceptiongroup when the exception *is* wrapped you need to" + " do e.g. `if isinstance(exc.value, ExceptionGroup):" + " assert RaisesGroup(...).matches(exc.value)` afterwards." + ) + with pytest.raises(ValueError, match=re.escape(msg)): + RaisesGroup(ValueError, allow_unwrapped=True, match="foo") # type: ignore[call-overload] + with pytest.raises(ValueError, match=re.escape(msg)): + RaisesGroup(ValueError, allow_unwrapped=True, check=my_check) # type: ignore[call-overload] + + # Users should instead use a Matcher + rg = RaisesGroup(Matcher(ValueError, match="^foo$"), allow_unwrapped=True) + with rg: + raise ValueError("foo") + with rg: + raise ExceptionGroup("", [ValueError("foo")]) + + # or if they wanted to match/check the group, do a conditional `.matches()` + with RaisesGroup(ValueError, allow_unwrapped=True) as exc: + raise ExceptionGroup("bar", [ValueError("foo")]) + if isinstance(exc.value, ExceptionGroup): # pragma: no branch + assert RaisesGroup(ValueError, match="bar").matches(exc.value) + + def test_RaisesGroup_matches() -> None: rg = RaisesGroup(ValueError) assert not rg.matches(None) @@ -216,6 +316,13 @@ def test_matcher_match() -> None: with RaisesGroup(Matcher(match="foo")): raise ExceptionGroup("", (ValueError("bar"),)) + # check ^$ + with RaisesGroup(Matcher(ValueError, match="^bar$")): + raise ExceptionGroup("", [ValueError("bar")]) + with pytest.raises(ExceptionGroup): + with RaisesGroup(Matcher(ValueError, match="^bar$")): + raise ExceptionGroup("", [ValueError("barr")]) + def test_Matcher_check() -> None: def check_oserror_and_errno_is_5(e: BaseException) -> bool: @@ -260,3 +367,12 @@ def test__ExceptionInfo(monkeypatch: pytest.MonkeyPatch) -> None: assert excinfo.type is ExceptionGroup assert excinfo.value.exceptions[0].args == ("hello",) assert isinstance(excinfo.tb, TracebackType) + + +def test_deprecated_strict() -> None: + """`strict` has been replaced with `flatten_subgroups`""" + # parameter is not included in overloaded signatures at all + with pytest.deprecated_call(): + RaisesGroup(ValueError, strict=False) # type: ignore[call-overload] + with pytest.deprecated_call(): + RaisesGroup(ValueError, strict=True) # type: ignore[call-overload] diff --git a/src/trio/_tests/type_tests/raisesgroup.py b/src/trio/_tests/type_tests/raisesgroup.py index d33f66fbe5..ba88eb09cc 100644 --- a/src/trio/_tests/type_tests/raisesgroup.py +++ b/src/trio/_tests/type_tests/raisesgroup.py @@ -66,12 +66,8 @@ def check_matcher_init() -> None: def check_exc(exc: BaseException) -> bool: return isinstance(exc, ValueError) - def check_filenotfound(exc: FileNotFoundError) -> bool: - return not exc.filename.endswith(".tmp") - # Check various combinations of constructor signatures. - # At least 1 arg must be provided. If exception_type is provided, that narrows - # check's argument. + # At least 1 arg must be provided. Matcher() # type: ignore Matcher(ValueError) Matcher(ValueError, "regex") @@ -79,13 +75,80 @@ def check_filenotfound(exc: FileNotFoundError) -> bool: Matcher(exception_type=ValueError) Matcher(match="regex") Matcher(check=check_exc) - Matcher(check=check_filenotfound) # type: ignore Matcher(ValueError, match="regex") - Matcher(FileNotFoundError, check=check_filenotfound) Matcher(match="regex", check=check_exc) + + def check_filenotfound(exc: FileNotFoundError) -> bool: + return not exc.filename.endswith(".tmp") + + # If exception_type is provided, that narrows the `check` method's argument. + Matcher(FileNotFoundError, check=check_filenotfound) + Matcher(ValueError, check=check_filenotfound) # type: ignore + Matcher(check=check_filenotfound) # type: ignore Matcher(FileNotFoundError, match="regex", check=check_filenotfound) +def raisesgroup_check_type_narrowing() -> None: + """Check type narrowing on the `check` argument to `RaisesGroup`. + All `type: ignore`s are correctly pointing out type errors, except + where otherwise noted. + + + """ + + def handle_exc(e: BaseExceptionGroup[BaseException]) -> bool: + return True + + def handle_kbi(e: BaseExceptionGroup[KeyboardInterrupt]) -> bool: + return True + + def handle_value(e: BaseExceptionGroup[ValueError]) -> bool: + return True + + RaisesGroup(BaseException, check=handle_exc) + RaisesGroup(BaseException, check=handle_kbi) # type: ignore + + RaisesGroup(Exception, check=handle_exc) + RaisesGroup(Exception, check=handle_value) # type: ignore + + RaisesGroup(KeyboardInterrupt, check=handle_exc) + RaisesGroup(KeyboardInterrupt, check=handle_kbi) + RaisesGroup(KeyboardInterrupt, check=handle_value) # type: ignore + + RaisesGroup(ValueError, check=handle_exc) + RaisesGroup(ValueError, check=handle_kbi) # type: ignore + RaisesGroup(ValueError, check=handle_value) + + RaisesGroup(ValueError, KeyboardInterrupt, check=handle_exc) + RaisesGroup(ValueError, KeyboardInterrupt, check=handle_kbi) # type: ignore + RaisesGroup(ValueError, KeyboardInterrupt, check=handle_value) # type: ignore + + +def raisesgroup_narrow_baseexceptiongroup() -> None: + """Check type narrowing specifically for the container exceptiongroup. + This is not currently working, and after playing around with it for a bit + I think the only way is to introduce a subclass `NonBaseRaisesGroup`, and overload + `__new__` in Raisesgroup to return the subclass when exceptions are non-base. + (or make current class BaseRaisesGroup and introduce RaisesGroup for non-base) + I encountered problems trying to type this though, see + https://github.com/python/mypy/issues/17251 + That is probably possible to work around by entirely using `__new__` instead of + `__init__`, but........ ugh. + """ + + def handle_group(e: ExceptionGroup[Exception]) -> bool: + return True + + def handle_group_value(e: ExceptionGroup[ValueError]) -> bool: + return True + + # should work, but BaseExceptionGroup does not get narrowed to ExceptionGroup + RaisesGroup(ValueError, check=handle_group_value) # type: ignore + + # should work, but BaseExceptionGroup does not get narrowed to ExceptionGroup + RaisesGroup(Exception, check=handle_group) # type: ignore + + def check_matcher_transparent() -> None: with RaisesGroup(Matcher(ValueError)) as e: ... @@ -126,3 +189,56 @@ def check_nested_raisesgroups_matches() -> None: # has the same problems as check_nested_raisesgroups_contextmanager if RaisesGroup(RaisesGroup(ValueError)).matches(exc): assert_type(exc, BaseExceptionGroup[RaisesGroup[ValueError]]) + + +def check_multiple_exceptions_1() -> None: + a = RaisesGroup(ValueError, ValueError) + b = RaisesGroup(Matcher(ValueError), Matcher(ValueError)) + c = RaisesGroup(ValueError, Matcher(ValueError)) + + d: BaseExceptionGroup[ValueError] + d = a + d = b + d = c + assert d + + +def check_multiple_exceptions_2() -> None: + # This previously failed due to lack of covariance in the TypeVar + a = RaisesGroup(Matcher(ValueError), Matcher(TypeError)) + b = RaisesGroup(Matcher(ValueError), TypeError) + c = RaisesGroup(ValueError, TypeError) + + d: BaseExceptionGroup[Exception] + d = a + d = b + d = c + assert d + + +def check_raisesgroup_overloads() -> None: + # allow_unwrapped=True does not allow: + # multiple exceptions + RaisesGroup(ValueError, TypeError, allow_unwrapped=True) # type: ignore + # nested RaisesGroup + RaisesGroup(RaisesGroup(ValueError), allow_unwrapped=True) # type: ignore + # specifying match + RaisesGroup(ValueError, match="foo", allow_unwrapped=True) # type: ignore + # specifying check + RaisesGroup(ValueError, check=bool, allow_unwrapped=True) # type: ignore + # allowed variants + RaisesGroup(ValueError, allow_unwrapped=True) + RaisesGroup(ValueError, allow_unwrapped=True, flatten_subgroups=True) + RaisesGroup(Matcher(ValueError), allow_unwrapped=True) + + # flatten_subgroups=True does not allow nested RaisesGroup + RaisesGroup(RaisesGroup(ValueError), flatten_subgroups=True) # type: ignore + # but rest is plenty fine + RaisesGroup(ValueError, TypeError, flatten_subgroups=True) + RaisesGroup(ValueError, match="foo", flatten_subgroups=True) + RaisesGroup(ValueError, check=bool, flatten_subgroups=True) + RaisesGroup(ValueError, flatten_subgroups=True) + RaisesGroup(Matcher(ValueError), flatten_subgroups=True) + + # if they're both false we can of course specify nested raisesgroup + RaisesGroup(RaisesGroup(ValueError)) diff --git a/src/trio/_threads.py b/src/trio/_threads.py index b002a58552..df595c1c9c 100644 --- a/src/trio/_threads.py +++ b/src/trio/_threads.py @@ -367,6 +367,7 @@ async def to_thread_run_sync( # type: ignore[misc] "0.23.0", issue=2841, instead="`abandon_on_cancel=`", + use_triodeprecationwarning=True, ) abandon_on_cancel = cancellable # raise early if abandon_on_cancel.__bool__ raises diff --git a/src/trio/testing/_raises_group.py b/src/trio/testing/_raises_group.py index b7d0db0486..16bde651f4 100644 --- a/src/trio/testing/_raises_group.py +++ b/src/trio/testing/_raises_group.py @@ -7,13 +7,15 @@ Callable, ContextManager, Generic, - Iterable, + Literal, Pattern, + Sequence, TypeVar, cast, overload, ) +from trio._deprecate import warn_deprecated from trio._util import final if TYPE_CHECKING: @@ -29,7 +31,7 @@ if sys.version_info < (3, 11): from exceptiongroup import BaseExceptionGroup -E = TypeVar("E", bound=BaseException) +E = TypeVar("E", bound=BaseException, covariant=True) @final @@ -122,7 +124,7 @@ def getrepr( def _stringify_exception(exc: BaseException) -> str: return "\n".join( [ - str(exc), + getattr(exc, "message", str(exc)), *getattr(exc, "__notes__", []), ] ) @@ -263,8 +265,20 @@ class RaisesGroup(ContextManager[ExceptionInfo[BaseExceptionGroup[E]]], SuperCla This works similar to ``pytest.raises``, and a version of it will hopefully be added upstream, after which this can be deprecated and removed. See https://github.com/pytest-dev/pytest/issues/11538 - This differs from :ref:`except* ` in that all specified exceptions must be present, *and no others*. It will similarly not catch exceptions *not* wrapped in an exceptiongroup. - If you don't care for the nesting level of the exceptions you can pass ``strict=False``. + The catching behaviour differs from :ref:`except* ` in multiple different ways, being much stricter by default. By using ``allow_unwrapped=True`` and ``flatten_subgroups=True`` you can match ``except*`` fully when expecting a single exception. + + #. All specified exceptions must be present, *and no others*. + + * If you expect a variable number of exceptions you need to use ``pytest.raises(ExceptionGroup)`` and manually check the contained exceptions. Consider making use of :func:`Matcher.matches`. + + #. It will only catch exceptions wrapped in an exceptiongroup by default. + + * With ``allow_unwrapped=True`` you can specify a single expected exception or `Matcher` and it will match the exception even if it is not inside an `ExceptionGroup`. If you expect one of several different exception types you need to use a `Matcher` object. + + #. By default it cares about the full structure with nested `ExceptionGroup`'s. You can specify nested `ExceptionGroup`'s by passing `RaisesGroup` objects as expected exceptions. + + * With ``flatten_subgroups=True`` it will "flatten" the raised `ExceptionGroup`, extracting all exceptions inside any nested :class:`ExceptionGroup`, before matching. + It currently does not care about the order of the exceptions, so ``RaisesGroups(ValueError, TypeError)`` is equivalent to ``RaisesGroups(TypeError, ValueError)``. This class is not as polished as ``pytest.raises``, and is currently not as helpful in e.g. printing diffs when strings don't match, suggesting you use ``re.escape``, etc. @@ -280,14 +294,19 @@ class RaisesGroup(ContextManager[ExceptionInfo[BaseExceptionGroup[E]]], SuperCla with RaisesGroups(RaisesGroups(ValueError)): raise ExceptionGroup("", (ExceptionGroup("", (ValueError(),)),)) - with RaisesGroups(ValueError, strict=False): + # flatten_subgroups + with RaisesGroups(ValueError, flatten_subgroups=True): raise ExceptionGroup("", (ExceptionGroup("", (ValueError(),)),)) + # allow_unwrapped + with RaisesGroups(ValueError, allow_unwrapped=True): + raise ValueError + `RaisesGroup.matches` can also be used directly to check a standalone exception group. - This class is also not perfectly smart, e.g. this will likely fail currently:: + The matching algorithm is greedy, which means cases such as this may fail:: with RaisesGroups(ValueError, Matcher(ValueError, match="hello")): raise ExceptionGroup("", (ValueError("hello"), ValueError("goodbye"))) @@ -303,32 +322,108 @@ class RaisesGroup(ContextManager[ExceptionInfo[BaseExceptionGroup[E]]], SuperCla def __new__(cls, *args: object, **kwargs: object) -> RaisesGroup[E]: ... + # allow_unwrapped=True requires: singular exception, exception not being + # RaisesGroup instance, match is None, check is None + @overload + def __init__( + self, + exception: type[E] | Matcher[E], + *, + allow_unwrapped: Literal[True], + flatten_subgroups: bool = False, + match: None = None, + check: None = None, + ): ... + + # flatten_subgroups = True also requires no nested RaisesGroup + @overload + def __init__( + self, + exception: type[E] | Matcher[E], + *other_exceptions: type[E] | Matcher[E], + allow_unwrapped: Literal[False] = False, + flatten_subgroups: Literal[True], + match: str | Pattern[str] | None = None, + check: Callable[[BaseExceptionGroup[E]], bool] | None = None, + ): ... + + @overload def __init__( self, exception: type[E] | Matcher[E] | E, *other_exceptions: type[E] | Matcher[E] | E, - strict: bool = True, + allow_unwrapped: Literal[False] = False, + flatten_subgroups: Literal[False] = False, match: str | Pattern[str] | None = None, check: Callable[[BaseExceptionGroup[E]], bool] | None = None, + ): ... + + def __init__( + self, + exception: type[E] | Matcher[E] | E, + *other_exceptions: type[E] | Matcher[E] | E, + allow_unwrapped: bool = False, + flatten_subgroups: bool = False, + match: str | Pattern[str] | None = None, + check: Callable[[BaseExceptionGroup[E]], bool] | None = None, + strict: None = None, ): self.expected_exceptions: tuple[type[E] | Matcher[E] | E, ...] = ( exception, *other_exceptions, ) - self.strict = strict + self.flatten_subgroups: bool = flatten_subgroups + self.allow_unwrapped = allow_unwrapped self.match_expr = match self.check = check self.is_baseexceptiongroup = False + if strict is not None: + warn_deprecated( + "The `strict` parameter", + "0.25.1", + issue=2989, + instead="flatten_subgroups=True (for strict=False}", + ) + self.flatten_subgroups = not strict + + if allow_unwrapped and other_exceptions: + raise ValueError( + "You cannot specify multiple exceptions with `allow_unwrapped=True.`" + " If you want to match one of multiple possible exceptions you should" + " use a `Matcher`." + " E.g. `Matcher(check=lambda e: isinstance(e, (...)))`" + ) + if allow_unwrapped and isinstance(exception, RaisesGroup): + raise ValueError( + "`allow_unwrapped=True` has no effect when expecting a `RaisesGroup`." + " You might want it in the expected `RaisesGroup`, or" + " `flatten_subgroups=True` if you don't care about the structure." + ) + if allow_unwrapped and (match is not None or check is not None): + raise ValueError( + "`allow_unwrapped=True` bypasses the `match` and `check` parameters" + " if the exception is unwrapped. If you intended to match/check the" + " exception you should use a `Matcher` object. If you want to match/check" + " the exceptiongroup when the exception *is* wrapped you need to" + " do e.g. `if isinstance(exc.value, ExceptionGroup):" + " assert RaisesGroup(...).matches(exc.value)` afterwards." + ) + + # verify `expected_exceptions` and set `self.is_baseexceptiongroup` for exc in self.expected_exceptions: if isinstance(exc, RaisesGroup): - if not strict: + if self.flatten_subgroups: raise ValueError( "You cannot specify a nested structure inside a RaisesGroup with" - " strict=False" + " `flatten_subgroups=True`. The parameter will flatten subgroups" + " in the raised exceptiongroup before matching, which would never" + " match a nested structure." ) self.is_baseexceptiongroup |= exc.is_baseexceptiongroup elif isinstance(exc, Matcher): + # The Matcher could match BaseExceptions through the other arguments + # but `self.is_baseexceptiongroup` is only used for printing. if exc.exception_type is None: continue # Matcher __init__ assures it's a subclass of BaseException @@ -348,9 +443,9 @@ def __enter__(self) -> ExceptionInfo[BaseExceptionGroup[E]]: return self.excinfo def _unroll_exceptions( - self, exceptions: Iterable[BaseException] - ) -> Iterable[BaseException]: - """Used in non-strict mode.""" + self, exceptions: Sequence[BaseException] + ) -> Sequence[BaseException]: + """Used if `flatten_subgroups=True`.""" res: list[BaseException] = [] for exc in exceptions: if isinstance(exc, BaseExceptionGroup): @@ -383,32 +478,38 @@ def matches( # maybe have a list of strings logging failed matches, that __exit__ can # recursively step through and print on a failing match. if not isinstance(exc_val, BaseExceptionGroup): + if self.allow_unwrapped: + exp_exc = self.expected_exceptions[0] + if isinstance(exp_exc, Matcher) and exp_exc.matches(exc_val): + return True + if isinstance(exp_exc, type) and isinstance(exc_val, exp_exc): + return True return False - if len(exc_val.exceptions) != len(self.expected_exceptions): - return False + if self.match_expr is not None and not re.search( self.match_expr, _stringify_exception(exc_val) ): return False if self.check is not None and not self.check(exc_val): return False + remaining_exceptions = list(self.expected_exceptions) - actual_exceptions: Iterable[BaseException] = exc_val.exceptions - if not self.strict: + actual_exceptions: Sequence[BaseException] = exc_val.exceptions + if self.flatten_subgroups: actual_exceptions = self._unroll_exceptions(actual_exceptions) + # important to check the length *after* flattening subgroups + if len(actual_exceptions) != len(self.expected_exceptions): + return False + # it should be possible to get RaisesGroup.matches typed so as not to - # need these type: ignores, but I'm not sure that's possible while also having it + # need type: ignore, but I'm not sure that's possible while also having it # transparent for the end user. for e in actual_exceptions: for rem_e in remaining_exceptions: if ( (isinstance(rem_e, type) and isinstance(e, rem_e)) - or ( - isinstance(e, BaseExceptionGroup) - and isinstance(rem_e, RaisesGroup) - and rem_e.matches(e) - ) + or (isinstance(rem_e, RaisesGroup) and rem_e.matches(e)) or (isinstance(rem_e, Matcher) and rem_e.matches(e)) ): remaining_exceptions.remove(rem_e) # type: ignore[arg-type] diff --git a/test-requirements.in b/test-requirements.in index ed7055d58a..e36ef18e87 100644 --- a/test-requirements.in +++ b/test-requirements.in @@ -33,5 +33,5 @@ sortedcontainers idna outcome sniffio -# 1.2.0 ships monkeypatching for apport excepthook -exceptiongroup >= 1.2.0; python_version < "3.11" +# 1.2.1 fixes types +exceptiongroup >= 1.2.1; python_version < "3.11"