Skip to content

Commit

Permalink
Implement spy_return_list (#417)
Browse files Browse the repository at this point in the history
Fix #378

Co-authored-by: Bruno Oliveira <bruno@soliv.dev>
  • Loading branch information
frank-lenormand and nicoddemus committed Mar 21, 2024
1 parent dc28a0e commit 6d5d6dc
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 23 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.rst
Expand Up @@ -4,6 +4,7 @@ Releases
UNRELEASED
----------

* `#417 <https://github.com/pytest-dev/pytest-mock/pull/417>`_: ``spy`` now has ``spy_return_list``, which is a list containing all the values returned by the spied function.
* ``pytest-mock`` now requires ``pytest>=6.2.5``.
* `#410 <https://github.com/pytest-dev/pytest-mock/pull/410>`_: pytest-mock's ``setup.py`` file is removed.
If you relied on this file, e.g. to install pytest using ``setup.py install``,
Expand Down
3 changes: 2 additions & 1 deletion docs/usage.rst
Expand Up @@ -79,7 +79,8 @@ are available (like ``assert_called_once_with`` or ``call_count`` in the example

In addition, spy objects contain two extra attributes:

* ``spy_return``: contains the returned value of the spied function.
* ``spy_return``: contains the last returned value of the spied function.
* ``spy_return_list``: contains a list of all returned values of the spied function (new in ``3.13``).
* ``spy_exception``: contain the last exception value raised by the spied function/method when
it was last called, or ``None`` if no exception was raised.

Expand Down
91 changes: 69 additions & 22 deletions src/pytest_mock/plugin.py
Expand Up @@ -5,11 +5,14 @@
import sys
import unittest.mock
import warnings
from dataclasses import dataclass
from dataclasses import field
from typing import Any
from typing import Callable
from typing import Dict
from typing import Generator
from typing import Iterable
from typing import Iterator
from typing import List
from typing import Mapping
from typing import Optional
Expand Down Expand Up @@ -43,16 +46,55 @@ class PytestMockWarning(UserWarning):
"""Base class for all warnings emitted by pytest-mock."""


@dataclass
class MockCacheItem:
mock: MockType
patch: Optional[Any] = None


@dataclass
class MockCache:
cache: List[MockCacheItem] = field(default_factory=list)

def find(self, mock: MockType) -> MockCacheItem:
the_mock = next(
(mock_item for mock_item in self.cache if mock_item.mock == mock), None
)
if the_mock is None:
raise ValueError("This mock object is not registered")
return the_mock

def add(self, mock: MockType, **kwargs: Any) -> MockCacheItem:
try:
return self.find(mock)
except ValueError:
self.cache.append(MockCacheItem(mock=mock, **kwargs))
return self.cache[-1]

def remove(self, mock: MockType) -> None:
mock_item = self.find(mock)
self.cache.remove(mock_item)

def clear(self) -> None:
self.cache.clear()

def __iter__(self) -> Iterator[MockCacheItem]:
return iter(self.cache)

def __reversed__(self) -> Iterator[MockCacheItem]:
return reversed(self.cache)


class MockerFixture:
"""
Fixture that provides the same interface to functions in the mock module,
ensuring that they are uninstalled at the end of each test.
"""

def __init__(self, config: Any) -> None:
self._patches_and_mocks: List[Tuple[Any, unittest.mock.MagicMock]] = []
self._mock_cache: MockCache = MockCache()
self.mock_module = mock_module = get_mock_module(config)
self.patch = self._Patcher(self._patches_and_mocks, mock_module) # type: MockerFixture._Patcher
self.patch = self._Patcher(self._mock_cache, mock_module) # type: MockerFixture._Patcher
# aliases for convenience
self.Mock = mock_module.Mock
self.MagicMock = mock_module.MagicMock
Expand All @@ -75,7 +117,7 @@ def create_autospec(
m: MockType = self.mock_module.create_autospec(
spec, spec_set, instance, **kwargs
)
self._patches_and_mocks.append((None, m))
self._mock_cache.add(m)
return m

def resetall(
Expand All @@ -93,37 +135,39 @@ def resetall(
else:
supports_reset_mock_with_args = (self.Mock,)

for p, m in self._patches_and_mocks:
for mock_item in self._mock_cache:
# See issue #237.
if not hasattr(m, "reset_mock"):
if not hasattr(mock_item.mock, "reset_mock"):
continue
if isinstance(m, supports_reset_mock_with_args):
m.reset_mock(return_value=return_value, side_effect=side_effect)
# NOTE: The mock may be a dictionary
if hasattr(mock_item.mock, "spy_return_list"):
mock_item.mock.spy_return_list = []
if isinstance(mock_item.mock, supports_reset_mock_with_args):
mock_item.mock.reset_mock(
return_value=return_value, side_effect=side_effect
)
else:
m.reset_mock()
mock_item.mock.reset_mock()

def stopall(self) -> None:
"""
Stop all patchers started by this fixture. Can be safely called multiple
times.
"""
for p, m in reversed(self._patches_and_mocks):
if p is not None:
p.stop()
self._patches_and_mocks.clear()
for mock_item in reversed(self._mock_cache):
if mock_item.patch is not None:
mock_item.patch.stop()
self._mock_cache.clear()

def stop(self, mock: unittest.mock.MagicMock) -> None:
"""
Stops a previous patch or spy call by passing the ``MagicMock`` object
returned by it.
"""
for index, (p, m) in enumerate(self._patches_and_mocks):
if mock is m:
p.stop()
del self._patches_and_mocks[index]
break
else:
raise ValueError("This mock object is not registered")
mock_item = self._mock_cache.find(mock)
if mock_item.patch:
mock_item.patch.stop()
self._mock_cache.remove(mock)

def spy(self, obj: object, name: str) -> MockType:
"""
Expand All @@ -146,6 +190,7 @@ def wrapper(*args, **kwargs):
raise
else:
spy_obj.spy_return = r
spy_obj.spy_return_list.append(r)
return r

async def async_wrapper(*args, **kwargs):
Expand All @@ -158,6 +203,7 @@ async def async_wrapper(*args, **kwargs):
raise
else:
spy_obj.spy_return = r
spy_obj.spy_return_list.append(r)
return r

if asyncio.iscoroutinefunction(method):
Expand All @@ -169,6 +215,7 @@ async def async_wrapper(*args, **kwargs):

spy_obj = self.patch.object(obj, name, side_effect=wrapped, autospec=autospec)
spy_obj.spy_return = None
spy_obj.spy_return_list = []
spy_obj.spy_exception = None
return spy_obj

Expand Down Expand Up @@ -206,8 +253,8 @@ class _Patcher:

DEFAULT = object()

def __init__(self, patches_and_mocks, mock_module):
self.__patches_and_mocks = patches_and_mocks
def __init__(self, mock_cache, mock_module):
self.__mock_cache = mock_cache
self.mock_module = mock_module

def _start_patch(
Expand All @@ -219,7 +266,7 @@ def _start_patch(
"""
p = mock_func(*args, **kwargs)
mocked: MockType = p.start()
self.__patches_and_mocks.append((p, mocked))
self.__mock_cache.add(mock=mocked, patch=p)
if hasattr(mocked, "reset_mock"):
# check if `mocked` is actually a mock object, as depending on autospec or target
# parameters `mocked` can be anything
Expand Down
21 changes: 21 additions & 0 deletions tests/test_pytest_mock.py
Expand Up @@ -279,8 +279,13 @@ def bar(self, arg):
assert other.bar(arg=10) == 20
foo.bar.assert_called_once_with(arg=10) # type:ignore[attr-defined]
assert foo.bar.spy_return == 20 # type:ignore[attr-defined]
assert foo.bar.spy_return_list == [20] # type:ignore[attr-defined]
spy.assert_called_once_with(arg=10)
assert spy.spy_return == 20
assert foo.bar(arg=11) == 22
assert foo.bar(arg=12) == 24
assert spy.spy_return == 24
assert spy.spy_return_list == [20, 22, 24]


# Ref: https://docs.python.org/3/library/exceptions.html#exception-hierarchy
Expand Down Expand Up @@ -358,10 +363,12 @@ def bar(self, x):

spy = mocker.spy(Foo, "bar")
assert spy.spy_return is None
assert spy.spy_return_list == []
assert spy.spy_exception is None

Foo().bar(10)
assert spy.spy_return == 30
assert spy.spy_return_list == [30]
assert spy.spy_exception is None

# Testing spy can still be reset (#237).
Expand All @@ -370,10 +377,12 @@ def bar(self, x):
with pytest.raises(ValueError):
Foo().bar(0)
assert spy.spy_return is None
assert spy.spy_return_list == []
assert str(spy.spy_exception) == "invalid x"

Foo().bar(15)
assert spy.spy_return == 45
assert spy.spy_return_list == [45]
assert spy.spy_exception is None


Expand Down Expand Up @@ -409,6 +418,7 @@ class Foo(Base):
calls = [mocker.call(foo, arg=10), mocker.call(other, arg=10)]
assert spy.call_args_list == calls
assert spy.spy_return == 20
assert spy.spy_return_list == [20, 20]


@skip_pypy
Expand All @@ -422,8 +432,10 @@ def bar(cls, arg):
assert Foo.bar(arg=10) == 20
Foo.bar.assert_called_once_with(arg=10) # type:ignore[attr-defined]
assert Foo.bar.spy_return == 20 # type:ignore[attr-defined]
assert Foo.bar.spy_return_list == [20] # type:ignore[attr-defined]
spy.assert_called_once_with(arg=10)
assert spy.spy_return == 20
assert spy.spy_return_list == [20]


@skip_pypy
Expand All @@ -440,8 +452,10 @@ class Foo(Base):
assert Foo.bar(arg=10) == 20
Foo.bar.assert_called_once_with(arg=10) # type:ignore[attr-defined]
assert Foo.bar.spy_return == 20 # type:ignore[attr-defined]
assert Foo.bar.spy_return_list == [20] # type:ignore[attr-defined]
spy.assert_called_once_with(arg=10)
assert spy.spy_return == 20
assert spy.spy_return_list == [20]


@skip_pypy
Expand All @@ -460,8 +474,10 @@ def bar(cls, arg):
assert Foo.bar(arg=10) == 20
Foo.bar.assert_called_once_with(arg=10) # type:ignore[attr-defined]
assert Foo.bar.spy_return == 20 # type:ignore[attr-defined]
assert Foo.bar.spy_return_list == [20] # type:ignore[attr-defined]
spy.assert_called_once_with(arg=10)
assert spy.spy_return == 20
assert spy.spy_return_list == [20]


@skip_pypy
Expand All @@ -475,8 +491,10 @@ def bar(arg):
assert Foo.bar(arg=10) == 20
Foo.bar.assert_called_once_with(arg=10) # type:ignore[attr-defined]
assert Foo.bar.spy_return == 20 # type:ignore[attr-defined]
assert Foo.bar.spy_return_list == [20] # type:ignore[attr-defined]
spy.assert_called_once_with(arg=10)
assert spy.spy_return == 20
assert spy.spy_return_list == [20]


@skip_pypy
Expand All @@ -493,8 +511,10 @@ class Foo(Base):
assert Foo.bar(arg=10) == 20
Foo.bar.assert_called_once_with(arg=10) # type:ignore[attr-defined]
assert Foo.bar.spy_return == 20 # type:ignore[attr-defined]
assert Foo.bar.spy_return_list == [20] # type:ignore[attr-defined]
spy.assert_called_once_with(arg=10)
assert spy.spy_return == 20
assert spy.spy_return_list == [20]


def test_callable_like_spy(testdir: Any, mocker: MockerFixture) -> None:
Expand All @@ -515,6 +535,7 @@ def __call__(self, x):
uut.call_like(10)
spy.assert_called_once_with(10)
assert spy.spy_return == 20
assert spy.spy_return_list == [20]


async def test_instance_async_method_spy(mocker: MockerFixture) -> None:
Expand Down

0 comments on commit 6d5d6dc

Please sign in to comment.