diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 6ecdab3..fbf0c52 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,6 +1,11 @@ Releases ======== +UNRELEASED +---------- + +* `#524 `_: Added ``spy_return_iter`` to ``mocker.spy``, which contains a duplicate of the return value of the spied method if it is an ``Iterator``. + 3.14.1 (2025-05-26) ------------------- diff --git a/docs/usage.rst b/docs/usage.rst index 339746a..587fcb3 100644 --- a/docs/usage.rst +++ b/docs/usage.rst @@ -81,6 +81,7 @@ are available (like ``assert_called_once_with`` or ``call_count`` in the example In addition, spy objects contain two extra attributes: * ``spy_return``: contains the last returned value of the spied function. +* ``spy_return_iter``: contains a duplicate of the last returned value of the spied function if the value was an iterator. Uses `tee `__) to duplicate the iterator. * ``spy_return_list``: contains a list of all returned values of the spied function (new in ``3.13``). * ``spy_exception``: contain the last exception value raised by the spied function/method when it was last called, or ``None`` if no exception was raised. diff --git a/src/pytest_mock/plugin.py b/src/pytest_mock/plugin.py index 50dc06a..e22abc3 100644 --- a/src/pytest_mock/plugin.py +++ b/src/pytest_mock/plugin.py @@ -1,6 +1,7 @@ import builtins import functools import inspect +import itertools import unittest.mock import warnings from dataclasses import dataclass @@ -137,6 +138,8 @@ def resetall( # NOTE: The mock may be a dictionary if hasattr(mock_item.mock, "spy_return_list"): mock_item.mock.spy_return_list = [] + if hasattr(mock_item.mock, "spy_return_iter"): + mock_item.mock.spy_return_iter = None if isinstance(mock_item.mock, supports_reset_mock_with_args): mock_item.mock.reset_mock( return_value=return_value, side_effect=side_effect @@ -178,6 +181,12 @@ def wrapper(*args, **kwargs): spy_obj.spy_exception = e raise else: + if isinstance(r, Iterator): + r, duplicated_iterator = itertools.tee(r, 2) + spy_obj.spy_return_iter = duplicated_iterator + else: + spy_obj.spy_return_iter = None + spy_obj.spy_return = r spy_obj.spy_return_list.append(r) return r @@ -204,6 +213,7 @@ async def async_wrapper(*args, **kwargs): spy_obj = self.patch.object(obj, name, side_effect=wrapped, autospec=autospec) spy_obj.spy_return = None + spy_obj.spy_return_iter = None spy_obj.spy_return_list = [] spy_obj.spy_exception = None return spy_obj diff --git a/tests/test_pytest_mock.py b/tests/test_pytest_mock.py index 14187c2..174a362 100644 --- a/tests/test_pytest_mock.py +++ b/tests/test_pytest_mock.py @@ -7,6 +7,8 @@ from typing import Any from typing import Callable from typing import Generator +from typing import Iterable +from typing import Iterator from typing import Tuple from typing import Type from unittest.mock import AsyncMock @@ -265,12 +267,14 @@ def bar(self, arg): assert other.bar(arg=10) == 20 foo.bar.assert_called_once_with(arg=10) # type:ignore[attr-defined] assert foo.bar.spy_return == 20 # type:ignore[attr-defined] + assert foo.bar.spy_return_iter is None # type:ignore[attr-defined] assert foo.bar.spy_return_list == [20] # type:ignore[attr-defined] spy.assert_called_once_with(arg=10) assert spy.spy_return == 20 assert foo.bar(arg=11) == 22 assert foo.bar(arg=12) == 24 assert spy.spy_return == 24 + assert spy.spy_return_iter is None assert spy.spy_return_list == [20, 22, 24] @@ -349,11 +353,13 @@ def bar(self, x): spy = mocker.spy(Foo, "bar") assert spy.spy_return is None + assert spy.spy_return_iter is None assert spy.spy_return_list == [] assert spy.spy_exception is None Foo().bar(10) assert spy.spy_return == 30 + assert spy.spy_return_iter is None assert spy.spy_return_list == [30] assert spy.spy_exception is None @@ -363,11 +369,13 @@ def bar(self, x): with pytest.raises(ValueError): Foo().bar(0) assert spy.spy_return is None + assert spy.spy_return_iter is None assert spy.spy_return_list == [] assert str(spy.spy_exception) == "invalid x" Foo().bar(15) assert spy.spy_return == 45 + assert spy.spy_return_iter is None assert spy.spy_return_list == [45] assert spy.spy_exception is None @@ -404,6 +412,7 @@ class Foo(Base): calls = [mocker.call(foo, arg=10), mocker.call(other, arg=10)] assert spy.call_args_list == calls assert spy.spy_return == 20 + assert spy.spy_return_iter is None assert spy.spy_return_list == [20, 20] @@ -418,9 +427,11 @@ def bar(cls, arg): assert Foo.bar(arg=10) == 20 Foo.bar.assert_called_once_with(arg=10) # type:ignore[attr-defined] assert Foo.bar.spy_return == 20 # type:ignore[attr-defined] + assert Foo.bar.spy_return_iter is None # type:ignore[attr-defined] assert Foo.bar.spy_return_list == [20] # type:ignore[attr-defined] spy.assert_called_once_with(arg=10) assert spy.spy_return == 20 + assert spy.spy_return_iter is None assert spy.spy_return_list == [20] @@ -438,9 +449,11 @@ class Foo(Base): assert Foo.bar(arg=10) == 20 Foo.bar.assert_called_once_with(arg=10) # type:ignore[attr-defined] assert Foo.bar.spy_return == 20 # type:ignore[attr-defined] + assert Foo.bar.spy_return_iter is None # type:ignore[attr-defined] assert Foo.bar.spy_return_list == [20] # type:ignore[attr-defined] spy.assert_called_once_with(arg=10) assert spy.spy_return == 20 + assert spy.spy_return_iter is None assert spy.spy_return_list == [20] @@ -460,9 +473,11 @@ def bar(cls, arg): assert Foo.bar(arg=10) == 20 Foo.bar.assert_called_once_with(arg=10) # type:ignore[attr-defined] assert Foo.bar.spy_return == 20 # type:ignore[attr-defined] + assert Foo.bar.spy_return_iter is None # type:ignore[attr-defined] assert Foo.bar.spy_return_list == [20] # type:ignore[attr-defined] spy.assert_called_once_with(arg=10) assert spy.spy_return == 20 + assert spy.spy_return_iter is None assert spy.spy_return_list == [20] @@ -477,9 +492,11 @@ def bar(arg): assert Foo.bar(arg=10) == 20 Foo.bar.assert_called_once_with(arg=10) # type:ignore[attr-defined] assert Foo.bar.spy_return == 20 # type:ignore[attr-defined] + assert Foo.bar.spy_return_iter is None # type:ignore[attr-defined] assert Foo.bar.spy_return_list == [20] # type:ignore[attr-defined] spy.assert_called_once_with(arg=10) assert spy.spy_return == 20 + assert spy.spy_return_iter is None assert spy.spy_return_list == [20] @@ -497,9 +514,11 @@ class Foo(Base): assert Foo.bar(arg=10) == 20 Foo.bar.assert_called_once_with(arg=10) # type:ignore[attr-defined] assert Foo.bar.spy_return == 20 # type:ignore[attr-defined] + assert Foo.bar.spy_return_iter is None # type:ignore[attr-defined] assert Foo.bar.spy_return_list == [20] # type:ignore[attr-defined] spy.assert_called_once_with(arg=10) assert spy.spy_return == 20 + assert spy.spy_return_iter is None assert spy.spy_return_list == [20] @@ -521,9 +540,68 @@ def __call__(self, x): uut.call_like(10) spy.assert_called_once_with(10) assert spy.spy_return == 20 + assert spy.spy_return_iter is None assert spy.spy_return_list == [20] +@pytest.mark.parametrize("iterator", [(i for i in range(3)), iter([0, 1, 2])]) +def test_spy_return_iter(mocker: MockerFixture, iterator: Iterator[int]) -> None: + class Foo: + def bar(self) -> Iterator[int]: + return iterator + + foo = Foo() + spy = mocker.spy(foo, "bar") + result = list(foo.bar()) + + assert result == [0, 1, 2] + assert spy.spy_return is not None + assert spy.spy_return_iter is not None + assert list(spy.spy_return_iter) == result + + [return_value] = spy.spy_return_list + assert isinstance(return_value, Iterator) + + +@pytest.mark.parametrize("iterable", [(0, 1, 2), [0, 1, 2], range(3)]) +def test_spy_return_iter_ignore_plain_iterable( + mocker: MockerFixture, iterable: Iterable[int] +) -> None: + class Foo: + def bar(self) -> Iterable[int]: + return iterable + + foo = Foo() + spy = mocker.spy(foo, "bar") + result = foo.bar() + + assert result == iterable + assert spy.spy_return == result + assert spy.spy_return_iter is None + assert spy.spy_return_list == [result] + + +def test_spy_return_iter_resets(mocker: MockerFixture) -> None: + class Foo: + iterables: Any = [ + (i for i in range(3)), + 99, + ] + + def bar(self) -> Any: + return self.iterables.pop(0) + + foo = Foo() + spy = mocker.spy(foo, "bar") + result_iterator = list(foo.bar()) + + assert result_iterator == [0, 1, 2] + assert list(spy.spy_return_iter) == result_iterator + + assert foo.bar() == 99 + assert spy.spy_return_iter is None + + @pytest.mark.asyncio async def test_instance_async_method_spy(mocker: MockerFixture) -> None: class Foo: