diff --git a/src/_pytest/unittest.py b/src/_pytest/unittest.py index 7498f1b0002..cbf5d86e97d 100644 --- a/src/_pytest/unittest.py +++ b/src/_pytest/unittest.py @@ -40,7 +40,170 @@ if sys.version_info[:2] < (3, 11): + from collections import UserList + from functools import partial + from typing import cast + from typing import overload + from typing import TypeVar + import unittest.case + from exceptiongroup import ExceptionGroup + from typing_extensions import Self + from typing_extensions import SupportsIndex + + T = TypeVar("T") + + class ObservableList(UserList[T]): + def __init__( + self, + init: Iterable[T] | None = None, + on_change: Callable[[], None] = lambda: None, + ) -> None: + super().__init__([] if init is None else init) + self._on_change = on_change + + def _changed(self) -> None: + self._on_change() + + @overload + def __setitem__(self, i: SupportsIndex, item: T) -> None: + pass + + @overload + def __setitem__(self, i: slice, item: Iterable[T]) -> None: + pass + + def __setitem__(self, i, item) -> None: + super().__setitem__(i, item) + self._changed() + + @overload + def __delitem__(self, i: int) -> None: + pass + + @overload + def __delitem__(self, i: SupportsIndex | slice) -> None: + pass + + def __delitem__(self, i) -> None: + super().__delitem__(i) + self._changed() + + def __iadd__(self, other: Iterable[T]) -> Self: + result = super().__iadd__(other) + self._changed() + return result + + def __imul__(self, other: int) -> Self: + result = super().__imul__(other) + self._changed() + return result + + def append(self, x: T) -> None: + super().append(x) + self._changed() + + def insert(self, i: int, item: T) -> None: + super().insert(i, item) + self._changed() + + def pop(self, i: int = -1) -> T: + result = super().pop(i) + self._changed() + return result + + def remove(self, item: T) -> None: + super().remove(item) + self._changed() + + def clear(self) -> None: + super().clear() + self._changed() + + def reverse(self) -> None: + super().reverse() + self._changed() + + def sort(self, /, *args, **kwargs) -> None: + super().sort(*args, **kwargs) + self._changed() + + def extend(self, it: Iterable[T]) -> None: + super().extend(it) + self._changed() + + class _Outcome(unittest.case._Outcome): # type: ignore[name-defined] + def __init__(self, *args: object, **kwargs: object) -> None: + super().__init__(*args, **kwargs) + self._was_skipped_list_modified = True + self._was_errors_list_modified = True + self._last_id_of_skipped = id(self.skipped) + self._last_id_of_errors = id(self.errors) + self._non_subtest_skip: list[tuple[object, object]] = [] + self._subtest_errors: list[tuple[object, object]] = [] + + def _on_change(self, key: str) -> None: + if key == "skipped": + self._was_skipped_list_modified = True + elif key == "errors": + self._was_errors_list_modified = True + else: + raise RuntimeError("unreachable") + + @property + def skipped(self) -> list[tuple[object, object]]: + return cast(list[tuple[object, object]], self.__dict__["skipped"]) + + @skipped.setter + def skipped(self, value: list[tuple[object, object]]) -> None: + _on_change = partial(self._on_change, "skipped") + self.__dict__["skipped"] = ObservableList(value, _on_change) + self._last_id_of_skipped = id(self.skipped) + self._was_skipped_list_modified = True + + @property + def errors(self) -> list[tuple[object, object]]: + return cast(list[tuple[object, object]], self.__dict__["errors"]) + + @errors.setter + def errors(self, value: list[tuple[object, object]]) -> None: + _on_change = partial(self._on_change, "errors") + self.__dict__["errors"] = ObservableList(value, _on_change) + self._last_id_of_errors = id(self.errors) + self._was_errors_list_modified = True + + def non_subtest_skip(self) -> list[tuple[object, object]]: + from unittest.case import _SubTest # type: ignore[attr-defined] + + if id(self.skipped) != self._last_id_of_skipped: + self.skipped = self.skipped # calls the setter + + if self._was_skipped_list_modified: + self._non_subtest_skip = [ + (x, y) for x, y in self.skipped if not isinstance(x, _SubTest) + ] + self._was_skipped_list_modified = False + + return self._non_subtest_skip + + def subtest_errors(self) -> list[tuple[object, object]]: + from unittest.case import _SubTest # type: ignore[attr-defined] + + if id(self.errors) != self._last_id_of_errors: + self.errors = self.errors # calls the setter + + if self._was_errors_list_modified: + self._subtest_errors = [ + (x, y) + for x, y in self.errors + if isinstance(x, _SubTest) and y is not None + ] + self._was_errors_list_modified = False + + return self._subtest_errors + + unittest.case._Outcome = _Outcome # type: ignore[attr-defined] + if TYPE_CHECKING: from types import TracebackType @@ -313,11 +476,7 @@ def add_skip() -> None: # We also need to check if `self.instance._outcome` is `None` (this happens if the test # class/method is decorated with `unittest.skip`, see pytest-dev/pytest-subtests#173). if sys.version_info < (3, 11) and self.instance._outcome is not None: - subtest_errors = [ - x - for x, y in self.instance._outcome.errors - if isinstance(x, _SubTest) and y is not None - ] + subtest_errors = self.instance._outcome.subtest_errors() if len(subtest_errors) == 0: add_skip() else: @@ -443,18 +602,8 @@ def addSubTest( # For python < 3.11: add non-subtest skips once all subtest failures are processed by # `_addSubTest`. if sys.version_info < (3, 11): - from unittest.case import _SubTest # type: ignore[attr-defined] - - non_subtest_skip = [ - (x, y) - for x, y in self.instance._outcome.skipped - if not isinstance(x, _SubTest) - ] - subtest_errors = [ - (x, y) - for x, y in self.instance._outcome.errors - if isinstance(x, _SubTest) and y is not None - ] + non_subtest_skip = self.instance._outcome.non_subtest_skip() + subtest_errors = self.instance._outcome.subtest_errors() # Check if we have non-subtest skips: if there are also sub failures, non-subtest skips are not treated in # `_addSubTest` and have to be added using `add_skip` after all subtest failures are processed. if len(non_subtest_skip) > 0 and len(subtest_errors) > 0: diff --git a/testing/test_subtests.py b/testing/test_subtests.py index 6849df53622..15d70b93b8e 100644 --- a/testing/test_subtests.py +++ b/testing/test_subtests.py @@ -460,6 +460,30 @@ def test_zaz(self): ] ) + def test_passes_many_subtests( + self, pytester: pytest.Pytester, monkeypatch: pytest.MonkeyPatch + ) -> None: + # see https://github.com/pytest-dev/pytest/issues/13965 + + monkeypatch.setenv("COLUMNS", "120") + pytester.makepyfile( + """ + from unittest import TestCase + + class T(TestCase): + def test_foo(self): + for _ in range(1000 * 100): + with self.subTest(): + pass + """ + ) + result = pytester.runpytest() + result.stdout.fnmatch_lines( + [ + "* 1 passed in *", + ] + ) + def test_skip( self, pytester: pytest.Pytester,