Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
183 changes: 166 additions & 17 deletions src/_pytest/unittest.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,170 @@


if sys.version_info[:2] < (3, 11):
from collections import UserList
from functools import partial
from typing import cast
from typing import overload
from typing import TypeVar
import unittest.case

from exceptiongroup import ExceptionGroup
from typing_extensions import Self
from typing_extensions import SupportsIndex

T = TypeVar("T")

class ObservableList(UserList[T]):
def __init__(
self,
init: Iterable[T] | None = None,
on_change: Callable[[], None] = lambda: None,
) -> None:
super().__init__([] if init is None else init)
self._on_change = on_change

def _changed(self) -> None:
self._on_change()

@overload
def __setitem__(self, i: SupportsIndex, item: T) -> None:
pass

@overload
def __setitem__(self, i: slice, item: Iterable[T]) -> None:
pass

def __setitem__(self, i, item) -> None:
super().__setitem__(i, item)
self._changed()

@overload
def __delitem__(self, i: int) -> None:
pass

@overload
def __delitem__(self, i: SupportsIndex | slice) -> None:
pass

def __delitem__(self, i) -> None:
super().__delitem__(i)
self._changed()

def __iadd__(self, other: Iterable[T]) -> Self:
result = super().__iadd__(other)
self._changed()
return result

def __imul__(self, other: int) -> Self:
result = super().__imul__(other)
self._changed()
return result

def append(self, x: T) -> None:
super().append(x)
self._changed()

def insert(self, i: int, item: T) -> None:
super().insert(i, item)
self._changed()

def pop(self, i: int = -1) -> T:
result = super().pop(i)
self._changed()
return result

def remove(self, item: T) -> None:
super().remove(item)
self._changed()

def clear(self) -> None:
super().clear()
self._changed()

def reverse(self) -> None:
super().reverse()
self._changed()

def sort(self, /, *args, **kwargs) -> None:
super().sort(*args, **kwargs)
self._changed()

def extend(self, it: Iterable[T]) -> None:
super().extend(it)
self._changed()

class _Outcome(unittest.case._Outcome): # type: ignore[name-defined]
def __init__(self, *args: object, **kwargs: object) -> None:
super().__init__(*args, **kwargs)
self._was_skipped_list_modified = True
self._was_errors_list_modified = True
self._last_id_of_skipped = id(self.skipped)
self._last_id_of_errors = id(self.errors)
self._non_subtest_skip: list[tuple[object, object]] = []
self._subtest_errors: list[tuple[object, object]] = []

def _on_change(self, key: str) -> None:
if key == "skipped":
self._was_skipped_list_modified = True
elif key == "errors":
self._was_errors_list_modified = True
else:
raise RuntimeError("unreachable")

@property
def skipped(self) -> list[tuple[object, object]]:
return cast(list[tuple[object, object]], self.__dict__["skipped"])

@skipped.setter
def skipped(self, value: list[tuple[object, object]]) -> None:
_on_change = partial(self._on_change, "skipped")
self.__dict__["skipped"] = ObservableList(value, _on_change)
self._last_id_of_skipped = id(self.skipped)
self._was_skipped_list_modified = True

@property
def errors(self) -> list[tuple[object, object]]:
return cast(list[tuple[object, object]], self.__dict__["errors"])

@errors.setter
def errors(self, value: list[tuple[object, object]]) -> None:
_on_change = partial(self._on_change, "errors")
self.__dict__["errors"] = ObservableList(value, _on_change)
self._last_id_of_errors = id(self.errors)
self._was_errors_list_modified = True

def non_subtest_skip(self) -> list[tuple[object, object]]:
from unittest.case import _SubTest # type: ignore[attr-defined]

if id(self.skipped) != self._last_id_of_skipped:
self.skipped = self.skipped # calls the setter

if self._was_skipped_list_modified:
self._non_subtest_skip = [
(x, y) for x, y in self.skipped if not isinstance(x, _SubTest)
]
self._was_skipped_list_modified = False

return self._non_subtest_skip

def subtest_errors(self) -> list[tuple[object, object]]:
from unittest.case import _SubTest # type: ignore[attr-defined]

if id(self.errors) != self._last_id_of_errors:
self.errors = self.errors # calls the setter

if self._was_errors_list_modified:
self._subtest_errors = [
(x, y)
for x, y in self.errors
if isinstance(x, _SubTest) and y is not None
]
self._was_errors_list_modified = False

return self._subtest_errors

unittest.case._Outcome = _Outcome # type: ignore[attr-defined]


if TYPE_CHECKING:
from types import TracebackType
Expand Down Expand Up @@ -313,11 +476,7 @@ def add_skip() -> None:
# We also need to check if `self.instance._outcome` is `None` (this happens if the test
# class/method is decorated with `unittest.skip`, see pytest-dev/pytest-subtests#173).
if sys.version_info < (3, 11) and self.instance._outcome is not None:
subtest_errors = [
x
for x, y in self.instance._outcome.errors
if isinstance(x, _SubTest) and y is not None
]
subtest_errors = self.instance._outcome.subtest_errors()
if len(subtest_errors) == 0:
add_skip()
else:
Expand Down Expand Up @@ -443,18 +602,8 @@ def addSubTest(

# For python < 3.11: add non-subtest skips once all subtest failures are processed by # `_addSubTest`.
if sys.version_info < (3, 11):
from unittest.case import _SubTest # type: ignore[attr-defined]

non_subtest_skip = [
(x, y)
for x, y in self.instance._outcome.skipped
if not isinstance(x, _SubTest)
]
subtest_errors = [
(x, y)
for x, y in self.instance._outcome.errors
if isinstance(x, _SubTest) and y is not None
]
non_subtest_skip = self.instance._outcome.non_subtest_skip()
subtest_errors = self.instance._outcome.subtest_errors()
# Check if we have non-subtest skips: if there are also sub failures, non-subtest skips are not treated in
# `_addSubTest` and have to be added using `add_skip` after all subtest failures are processed.
if len(non_subtest_skip) > 0 and len(subtest_errors) > 0:
Expand Down
24 changes: 24 additions & 0 deletions testing/test_subtests.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,30 @@ def test_zaz(self):
]
)

def test_passes_many_subtests(
self, pytester: pytest.Pytester, monkeypatch: pytest.MonkeyPatch
) -> None:
# see https://github.com/pytest-dev/pytest/issues/13965

monkeypatch.setenv("COLUMNS", "120")
pytester.makepyfile(
"""
from unittest import TestCase

class T(TestCase):
def test_foo(self):
for _ in range(1000 * 100):
with self.subTest():
pass
"""
)
result = pytester.runpytest()
result.stdout.fnmatch_lines(
[
"* 1 passed in *",
]
)

def test_skip(
self,
pytester: pytest.Pytester,
Expand Down