Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse files

Added assert_chain_calls to the manager

  • Loading branch information...
commit 58d94d519ead9076180267abb82e5b3660cd43db 1 parent 3492ee2
@dcramer dcramer authored
Showing with 40 additions and 12 deletions.
  1. +27 −0 mock_django/managers.py
  2. +13 −12 tests/mock_django/managers/tests.py
View
27 mock_django/managers.py
@@ -12,6 +12,12 @@
class _ManagerMock(mock.MagicMock):
+ def __init__(self, *args, **kwargs):
+ super(_ManagerMock, self).__init__(*args, **kwargs)
+ parent = mock.MagicMock()
+ parent.child = self
+ self.__parent = parent
+
def _get_child_mock(self, **kwargs):
name = kwargs.get('name', '')
if name[:2] == name[-2:] == '__':
@@ -24,6 +30,27 @@ def __getattr__(self, name):
result._mock_name = result._mock_new_name = name
return result
+ def assert_chain_calls(self, *calls):
+ """
+ Asserts that a chained method was called (parents in the chain do not
+ matter, nor are they tracked).
+
+ >>> obj.assert_chain_calls(call.filter(foo='bar'))
+ >>> obj.assert_chain_calls(call.select_related('baz'))
+ """
+ all_calls = list(self.__parent.mock_calls)
+
+ not_found = []
+ for kall in calls:
+ try:
+ all_calls.remove(kall)
+ except ValueError:
+ not_found.append(kall)
+ if not_found:
+ raise AssertionError(
+ '%r not all found in call list' % (tuple(not_found),)
+ )
+
def ManagerMock(manager, *return_value):
"""
View
25 tests/mock_django/managers/tests.py
@@ -13,28 +13,29 @@ def make_manager():
class ManagerMockTestCase(TestCase):
def test_iter(self):
manager = make_manager()
- mock = ManagerMock(manager, 'foo')
- self.assertEquals(list(mock.all()), ['foo'])
+ inst = ManagerMock(manager, 'foo')
+ self.assertEquals(list(inst.all()), ['foo'])
def test_getitem(self):
manager = make_manager()
- mock = ManagerMock(manager, 'foo')
- self.assertEquals(mock.all()[0], 'foo')
+ inst = ManagerMock(manager, 'foo')
+ self.assertEquals(inst.all()[0], 'foo')
def test_returns_self(self):
manager = make_manager()
- mock = ManagerMock(manager, 'foo')
+ inst = ManagerMock(manager, 'foo')
- self.assertEquals(mock.all(), mock)
+ self.assertEquals(inst.all(), inst)
def test_call_tracking(self):
# only works in >= mock 0.8
manager = make_manager()
- mock = ManagerMock(manager, 'foo')
+ inst = ManagerMock(manager, 'foo')
- mock = mock.filter(foo='bar').select_related('baz')
- calls = mock.mock_calls
+ inst.filter(foo='bar').select_related('baz')
- self.assertEquals(len(calls), 2)
- self.assertEquals(calls[0], mock.call.filter(foo='bar'))
- self.assertEquals(calls[1], mock.call.select_related('baz'))
+ calls = inst.mock_calls
+
+ self.assertGreater(len(calls), 1)
+ inst.assert_chain_calls(mock.call.filter(foo='bar'))
+ inst.assert_chain_calls(mock.call.select_related('baz'))
Please sign in to comment.
Something went wrong with that request. Please try again.