Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse files

QuerySets are now mocked correctly when doing things like manager.all…

…()[0:1].get()
  • Loading branch information...
commit 2fe2d6f1b47bc8fb4a3ff97d2aa1267aac89db09 1 parent 3497ac4
@dcramer dcramer authored
View
28 mock_django/managers.py
@@ -7,6 +7,8 @@
"""
import mock
+from .query import QuerySetMock
+
__all__ = ('ManagerMock',)
@@ -68,29 +70,21 @@ def ManagerMock(manager, *return_value):
>>> objects = ManagerMock(Post.objects, Exception())
"""
- def make_get(self):
+ def make_get_query_set(self, actual_model):
def _get(*a, **k):
- results = list(self)
- if len(results) > 1:
- raise self.model.MultipleObjectsReturned
- try:
- return results[0]
- except IndexError:
- raise self.model.DoesNotExist
+ return QuerySetMock(actual_model, *return_value)
return _get
- model = getattr(manager, 'model', None)
- if model:
- model = mock.MagicMock(spec=manager.model())
+ actual_model = getattr(manager, 'model', None)
+ if actual_model:
+ model = mock.MagicMock(spec=actual_model())
else:
model = mock.MagicMock()
m = _ManagerMock()
m.model = model
- m.get = make_get(m)
- if len(return_value) == 1 and isinstance(return_value[0], Exception):
- m.__iter__.side_effect = return_value[0]
- else:
- m.__iter__.side_effect = lambda *a, **k: iter(return_value)
- m.__getitem__ = lambda s, n: list(s)[n]
+ m.get_query_set = make_get_query_set(m, actual_model)
+ m.get = m.get_query_set().get
+ m.__iter__ = m.get_query_set().__iter__
+ m.__getitem__ = m.get_query_set().__getitem__
return m
View
117 mock_django/query.py
@@ -0,0 +1,117 @@
+"""
+mock_django.query
+~~~~~~~~~~~~~~~~~
+
+:copyright: (c) 2012 DISQUS.
+:license: Apache License 2.0, see LICENSE for more details.
+"""
+
+import mock
+
+__all__ = ('QuerySetMock',)
+
+
+class _QuerySetMock(mock.MagicMock):
+ def __init__(self, *args, **kwargs):
+ super(_QuerySetMock, self).__init__(*args, **kwargs)
+ parent = mock.MagicMock()
+ parent.child = self
+ self.__parent = parent
+
+ def _get_child_mock(self, **kwargs):
+ name = kwargs.get('name', '')
+ if name[:2] == name[-2:] == '__':
+ return super(_QuerySetMock, self)._get_child_mock(**kwargs)
+ return self
+
+ def __getattr__(self, name):
+ result = super(_QuerySetMock, self).__getattr__(name)
+ if result is self:
+ result._mock_name = result._mock_new_name = name
+ return result
+
+ def assert_chain_calls(self, *calls):
+ """
+ Asserts that a chained method was called (parents in the chain do not
+ matter, nor are they tracked).
+
+ >>> obj.assert_chain_calls(call.filter(foo='bar'))
+ >>> obj.assert_chain_calls(call.select_related('baz'))
+ """
+ all_calls = self.__parent.mock_calls[:]
+
+ not_found = []
+ for kall in calls:
+ try:
+ all_calls.remove(kall)
+ except ValueError:
+ not_found.append(kall)
+ if not_found:
+ if self.__parent.mock_calls:
+ message = '%r not all found in call list, %d other(s) were:\n%r' % (not_found, len(self.__parent.mock_calls),
+ self.__parent.mock_calls)
+ else:
+ message = 'no calls were found'
+
+ raise AssertionError(message)
+
+
+def QuerySetMock(model, *return_value):
+ """
+ Set the results to two items:
+
+ >>> objects = QuerySetMock(Post, 'return', 'values')
+ >>> assert objects.filter() == objects.all()
+
+ Force an exception:
+
+ >>> objects = QuerySetMock(Post, Exception())
+ """
+
+ def make_get(self):
+ def _get(*a, **k):
+ results = list(self)
+ if len(results) > 1:
+ raise self.model.MultipleObjectsReturned
+ try:
+ return results[0]
+ except IndexError:
+ raise self.model.DoesNotExist
+ return _get
+
+ def make_getitem(self):
+ def _getitem(k):
+ if isinstance(k, slice):
+ self.__start = k.start
+ self.__stop = k.stop
+ else:
+ return list(self)[k]
+ return self
+ return _getitem
+
+ def make_iterator(self):
+ def _iterator(*a, **k):
+ if len(return_value) == 1 and isinstance(return_value[0], Exception):
+ raise return_value[0]
+
+ start = getattr(self, '__start', None)
+ stop = getattr(self, '__stop', None)
+ for x in return_value[start:stop]:
+ yield x
+ return _iterator
+
+ actual_model = model
+ if actual_model:
+ model = mock.MagicMock(spec=actual_model())
+ else:
+ model = mock.MagicMock()
+
+ m = _QuerySetMock()
+ m.__start = None
+ m.__stop = None
+ m.__iter__.side_effect = lambda: iter(m.iterator())
+ m.__getitem__.side_effect = make_getitem(m)
+ m.model = model
+ m.get = make_get(m)
+ m.iterator.side_effect = make_iterator(m)
+ return m
View
2  setup.py
@@ -2,7 +2,7 @@
setup(
name='mock-django',
- version='0.2.0',
+ version='0.3.0',
description='',
license='Apache License 2.0',
author='David Cramer',
View
5 tests/mock_django/managers/tests.py
@@ -64,3 +64,8 @@ def test_call_tracking(self):
self.assertGreater(len(calls), 1)
inst.assert_chain_calls(mock.call.filter(foo='bar'))
inst.assert_chain_calls(mock.call.select_related('baz'))
+
+ def test_getitem_get(self):
+ manager = make_manager()
+ inst = ManagerMock(manager, 'foo')
+ self.assertEquals(inst[0:1].get(), 'foo')
View
0  tests/mock_django/query/__init__.py
No changes.
View
0  tests/mock_django/query/tests.py
No changes.
Please sign in to comment.
Something went wrong with that request. Please try again.