Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions Lib/unittest/mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -1969,6 +1969,12 @@ def __aiter__():


def _set_return_value(mock, method, name):
# If _mock_wraps is present then attach it so that wrapped object
# is used for return value is used when called.
if mock._mock_wraps is not None:
method._mock_wraps = getattr(mock._mock_wraps, name)
return

fixed = _return_values.get(name, DEFAULT)
if fixed is not DEFAULT:
method.return_value = fixed
Expand Down
47 changes: 47 additions & 0 deletions Lib/unittest/test/testmock/testmock.py
Original file line number Diff line number Diff line change
Expand Up @@ -707,6 +707,53 @@ def method(self): pass
self.assertRaises(StopIteration, mock.method)


def test_magic_method_wraps_dict(self):
data = {'foo': 'bar'}

wrapped_dict = MagicMock(wraps=data)
self.assertEqual(wrapped_dict.get('foo'), 'bar')
self.assertEqual(wrapped_dict['foo'], 'bar')
self.assertTrue('foo' in wrapped_dict)

# return_value is non-sentinel and takes precedence over wrapped value.
wrapped_dict.get.return_value = 'return_value'
self.assertEqual(wrapped_dict.get('foo'), 'return_value')

# return_value is sentinel and hence wrapped value is returned.
wrapped_dict.get.return_value = sentinel.DEFAULT
self.assertEqual(wrapped_dict.get('foo'), 'bar')

self.assertEqual(wrapped_dict.get('baz'), None)
with self.assertRaises(KeyError):
wrapped_dict['baz']
self.assertFalse('bar' in wrapped_dict)

data['baz'] = 'spam'
self.assertEqual(wrapped_dict.get('baz'), 'spam')
self.assertEqual(wrapped_dict['baz'], 'spam')
self.assertTrue('baz' in wrapped_dict)

del data['baz']
self.assertEqual(wrapped_dict.get('baz'), None)


def test_magic_method_wraps_class(self):

class Foo:

def __getitem__(self, index):
return index

def __custom_method__(self):
return "foo"


klass = MagicMock(wraps=Foo)
obj = klass()
self.assertEqual(obj.__getitem__(2), 2)
self.assertEqual(obj.__custom_method__(), "foo")


def test_exceptional_side_effect(self):
mock = Mock(side_effect=AttributeError)
self.assertRaises(AttributeError, mock)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Ensure, if ``wraps`` is supplied to :class:`unittest.mock.MagicMock`, it is used
to calculate return values for the magic methods instead of using the default
return values. Patch by Karthikeyan Singaravelan.