Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement spy_return_list. #417

Merged
merged 5 commits into from Mar 21, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 2 additions & 1 deletion docs/usage.rst
Expand Up @@ -79,7 +79,8 @@ are available (like ``assert_called_once_with`` or ``call_count`` in the example

In addition, spy objects contain two extra attributes:

* ``spy_return``: contains the returned value of the spied function.
* ``spy_return``: contains the last returned value of the spied function.
* ``spy_return_list``: contains a list of all returned value of the spied function.
nicoddemus marked this conversation as resolved.
Show resolved Hide resolved
* ``spy_exception``: contain the last exception value raised by the spied function/method when
it was last called, or ``None`` if no exception was raised.

Expand Down
91 changes: 69 additions & 22 deletions src/pytest_mock/plugin.py
Expand Up @@ -5,11 +5,14 @@
import sys
import unittest.mock
import warnings
from dataclasses import dataclass
from dataclasses import field
from typing import Any
from typing import Callable
from typing import Dict
from typing import Generator
from typing import Iterable
from typing import Iterator
from typing import List
from typing import Mapping
from typing import Optional
Expand Down Expand Up @@ -43,16 +46,55 @@ class PytestMockWarning(UserWarning):
"""Base class for all warnings emitted by pytest-mock."""


@dataclass
class MockCacheItem:
mock: MockType
patch: Optional[Any] = None


@dataclass
class MockCache:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Loved this refactoring. 👍

cache: List[MockCacheItem] = field(default_factory=list)

def find(self, mock: MockType) -> MockCacheItem:
the_mock = next(
(mock_item for mock_item in self.cache if mock_item.mock == mock), None
)
if the_mock is None:
raise ValueError("This mock object is not registered")
return the_mock

def add(self, mock: MockType, **kwargs: Any) -> MockCacheItem:
try:
return self.find(mock)
except ValueError:
self.cache.append(MockCacheItem(mock=mock, **kwargs))
return self.cache[-1]

def remove(self, mock: MockType) -> None:
mock_item = self.find(mock)
self.cache.remove(mock_item)

def clear(self) -> None:
self.cache.clear()

def __iter__(self) -> Iterator[MockCacheItem]:
return iter(self.cache)

def __reversed__(self) -> Iterator[MockCacheItem]:
return reversed(self.cache)


class MockerFixture:
"""
Fixture that provides the same interface to functions in the mock module,
ensuring that they are uninstalled at the end of each test.
"""

def __init__(self, config: Any) -> None:
self._patches_and_mocks: List[Tuple[Any, unittest.mock.MagicMock]] = []
self._mock_cache: MockCache = MockCache()
self.mock_module = mock_module = get_mock_module(config)
self.patch = self._Patcher(self._patches_and_mocks, mock_module) # type: MockerFixture._Patcher
self.patch = self._Patcher(self._mock_cache, mock_module) # type: MockerFixture._Patcher
# aliases for convenience
self.Mock = mock_module.Mock
self.MagicMock = mock_module.MagicMock
Expand All @@ -75,7 +117,7 @@ def create_autospec(
m: MockType = self.mock_module.create_autospec(
spec, spec_set, instance, **kwargs
)
self._patches_and_mocks.append((None, m))
self._mock_cache.add(m)
return m

def resetall(
Expand All @@ -93,37 +135,39 @@ def resetall(
else:
supports_reset_mock_with_args = (self.Mock,)

for p, m in self._patches_and_mocks:
for mock_item in self._mock_cache:
# See issue #237.
if not hasattr(m, "reset_mock"):
if not hasattr(mock_item.mock, "reset_mock"):
continue
if isinstance(m, supports_reset_mock_with_args):
m.reset_mock(return_value=return_value, side_effect=side_effect)
# NOTE: The mock may be a dictionary
if hasattr(mock_item.mock, "spy_return_list"):
mock_item.mock.spy_return_list = []
if isinstance(mock_item.mock, supports_reset_mock_with_args):
mock_item.mock.reset_mock(
return_value=return_value, side_effect=side_effect
)
else:
m.reset_mock()
mock_item.mock.reset_mock()

def stopall(self) -> None:
"""
Stop all patchers started by this fixture. Can be safely called multiple
times.
"""
for p, m in reversed(self._patches_and_mocks):
if p is not None:
p.stop()
self._patches_and_mocks.clear()
for mock_item in reversed(self._mock_cache):
if mock_item.patch is not None:
mock_item.patch.stop()
self._mock_cache.clear()

def stop(self, mock: unittest.mock.MagicMock) -> None:
"""
Stops a previous patch or spy call by passing the ``MagicMock`` object
returned by it.
"""
for index, (p, m) in enumerate(self._patches_and_mocks):
if mock is m:
p.stop()
del self._patches_and_mocks[index]
break
else:
raise ValueError("This mock object is not registered")
mock_item = self._mock_cache.find(mock)
if mock_item.patch:
mock_item.patch.stop()
self._mock_cache.remove(mock)

def spy(self, obj: object, name: str) -> MockType:
"""
Expand All @@ -146,6 +190,7 @@ def wrapper(*args, **kwargs):
raise
else:
spy_obj.spy_return = r
spy_obj.spy_return_list.append(r)
return r

async def async_wrapper(*args, **kwargs):
Expand All @@ -158,6 +203,7 @@ async def async_wrapper(*args, **kwargs):
raise
else:
spy_obj.spy_return = r
spy_obj.spy_return_list.append(r)
return r

if asyncio.iscoroutinefunction(method):
Expand All @@ -169,6 +215,7 @@ async def async_wrapper(*args, **kwargs):

spy_obj = self.patch.object(obj, name, side_effect=wrapped, autospec=autospec)
spy_obj.spy_return = None
spy_obj.spy_return_list = []
spy_obj.spy_exception = None
return spy_obj

Expand Down Expand Up @@ -206,8 +253,8 @@ class _Patcher:

DEFAULT = object()

def __init__(self, patches_and_mocks, mock_module):
self.__patches_and_mocks = patches_and_mocks
def __init__(self, mock_cache, mock_module):
self.__mock_cache = mock_cache
self.mock_module = mock_module

def _start_patch(
Expand All @@ -219,7 +266,7 @@ def _start_patch(
"""
p = mock_func(*args, **kwargs)
mocked: MockType = p.start()
self.__patches_and_mocks.append((p, mocked))
self.__mock_cache.add(mock=mocked, patch=p)
if hasattr(mocked, "reset_mock"):
# check if `mocked` is actually a mock object, as depending on autospec or target
# parameters `mocked` can be anything
Expand Down
20 changes: 20 additions & 0 deletions tests/test_pytest_mock.py
Expand Up @@ -279,8 +279,12 @@ def bar(self, arg):
assert other.bar(arg=10) == 20
foo.bar.assert_called_once_with(arg=10) # type:ignore[attr-defined]
assert foo.bar.spy_return == 20 # type:ignore[attr-defined]
assert foo.bar.spy_return_list == [20] # type:ignore[attr-defined]
spy.assert_called_once_with(arg=10)
assert spy.spy_return == 20
assert foo.bar(arg=11) == 22
assert foo.bar(arg=12) == 24
assert spy.spy_return_list == [20, 22, 24]
nicoddemus marked this conversation as resolved.
Show resolved Hide resolved


# Ref: https://docs.python.org/3/library/exceptions.html#exception-hierarchy
Expand Down Expand Up @@ -358,10 +362,12 @@ def bar(self, x):

spy = mocker.spy(Foo, "bar")
assert spy.spy_return is None
assert spy.spy_return_list == []
assert spy.spy_exception is None

Foo().bar(10)
assert spy.spy_return == 30
assert spy.spy_return_list == [30]
assert spy.spy_exception is None

# Testing spy can still be reset (#237).
Expand All @@ -370,10 +376,12 @@ def bar(self, x):
with pytest.raises(ValueError):
Foo().bar(0)
assert spy.spy_return is None
assert spy.spy_return_list == []
assert str(spy.spy_exception) == "invalid x"

Foo().bar(15)
assert spy.spy_return == 45
assert spy.spy_return_list == [45]
assert spy.spy_exception is None


Expand Down Expand Up @@ -409,6 +417,7 @@ class Foo(Base):
calls = [mocker.call(foo, arg=10), mocker.call(other, arg=10)]
assert spy.call_args_list == calls
assert spy.spy_return == 20
assert spy.spy_return_list == [20, 20]


@skip_pypy
Expand All @@ -422,8 +431,10 @@ def bar(cls, arg):
assert Foo.bar(arg=10) == 20
Foo.bar.assert_called_once_with(arg=10) # type:ignore[attr-defined]
assert Foo.bar.spy_return == 20 # type:ignore[attr-defined]
assert Foo.bar.spy_return_list == [20] # type:ignore[attr-defined]
spy.assert_called_once_with(arg=10)
assert spy.spy_return == 20
assert spy.spy_return_list == [20]


@skip_pypy
Expand All @@ -440,8 +451,10 @@ class Foo(Base):
assert Foo.bar(arg=10) == 20
Foo.bar.assert_called_once_with(arg=10) # type:ignore[attr-defined]
assert Foo.bar.spy_return == 20 # type:ignore[attr-defined]
assert Foo.bar.spy_return_list == [20] # type:ignore[attr-defined]
spy.assert_called_once_with(arg=10)
assert spy.spy_return == 20
assert spy.spy_return_list == [20]


@skip_pypy
Expand All @@ -460,8 +473,10 @@ def bar(cls, arg):
assert Foo.bar(arg=10) == 20
Foo.bar.assert_called_once_with(arg=10) # type:ignore[attr-defined]
assert Foo.bar.spy_return == 20 # type:ignore[attr-defined]
assert Foo.bar.spy_return_list == [20] # type:ignore[attr-defined]
spy.assert_called_once_with(arg=10)
assert spy.spy_return == 20
assert spy.spy_return_list == [20]


@skip_pypy
Expand All @@ -475,8 +490,10 @@ def bar(arg):
assert Foo.bar(arg=10) == 20
Foo.bar.assert_called_once_with(arg=10) # type:ignore[attr-defined]
assert Foo.bar.spy_return == 20 # type:ignore[attr-defined]
assert Foo.bar.spy_return_list == [20] # type:ignore[attr-defined]
spy.assert_called_once_with(arg=10)
assert spy.spy_return == 20
assert spy.spy_return_list == [20]


@skip_pypy
Expand All @@ -493,8 +510,10 @@ class Foo(Base):
assert Foo.bar(arg=10) == 20
Foo.bar.assert_called_once_with(arg=10) # type:ignore[attr-defined]
assert Foo.bar.spy_return == 20 # type:ignore[attr-defined]
assert Foo.bar.spy_return_list == [20] # type:ignore[attr-defined]
spy.assert_called_once_with(arg=10)
assert spy.spy_return == 20
assert spy.spy_return_list == [20]


def test_callable_like_spy(testdir: Any, mocker: MockerFixture) -> None:
Expand All @@ -515,6 +534,7 @@ def __call__(self, x):
uut.call_like(10)
spy.assert_called_once_with(10)
assert spy.spy_return == 20
assert spy.spy_return_list == [20]


async def test_instance_async_method_spy(mocker: MockerFixture) -> None:
Expand Down