Skip to content
This repository

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
tree: 3edb99afb8
Fetching contributors…

Cannot retrieve contributors at this time

file 84 lines (67 sloc) 2.269 kb
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84
"""
mock_django.query
~~~~~~~~~~~~~~~~~

:copyright: (c) 2012 DISQUS.
:license: Apache License 2.0, see LICENSE for more details.
"""

import mock
from .shared import SharedMock

# Public API of this module: only the QuerySetMock factory is exported by
# ``from mock_django.query import *``.
__all__ = ('QuerySetMock',)


def QuerySetMock(model, *return_value):
    """
    Build a SharedMock that imitates a Django QuerySet over ``return_value``.

    :param model: model class the queryset belongs to, or a falsy value.
        When given, ``.model`` is a MagicMock spec'd on an instance of it and
        ``.get()`` raises its ``MultipleObjectsReturned`` / ``DoesNotExist``.
        When falsy, ``.model`` is a plain MagicMock and ``.get()`` can only
        succeed when exactly one result exists (otherwise it fails with an
        AttributeError on the missing model).
    :param return_value: the objects the queryset yields. Passing a single
        Exception instance instead makes evaluating the queryset raise it.

    Set the results to two items:

    >>> class Post(object):
    ...     pass
    >>> objects = QuerySetMock(Post, 'return', 'values')
    >>> list(objects.filter()) == list(objects.all()) == ['return', 'values']
    True

    Slicing returns a new queryset mock over the selected items and leaves
    the original untouched:

    >>> list(objects[1:])
    ['values']
    >>> list(objects)
    ['return', 'values']

    Force an exception (raised when the queryset is evaluated):

    >>> failing = QuerySetMock(Post, Exception())

    Note that only methods returning querysets are currently
    explicitly supported; since we use SharedMock, others all behave
    as if they did, so use with caution:

    >>> list(objects.count()) == list(objects.all())
    True
    """

    # A lone Exception instance means "evaluating this queryset raises".
    raises = (len(return_value) == 1 and
              isinstance(return_value[0], Exception))

    def make_get(self, model_cls):
        """Return a ``get()`` that demands exactly one result from ``self``."""
        def _get(*a, **k):
            results = list(self)
            if len(results) > 1:
                raise model_cls.MultipleObjectsReturned
            try:
                return results[0]
            except IndexError:
                raise model_cls.DoesNotExist
        return _get

    def make_getitem(self):
        """Return a ``__getitem__`` supporting both indexing and slicing."""
        def _getitem(k):
            if not isinstance(k, slice):
                # Plain index: evaluate the queryset and index the results.
                return list(self)[k]
            # Slicing must not mutate ``self``: the old start/stop state was
            # stored on the shared mock and leaked into every later
            # evaluation. Return a fresh mock instead, as Django returns a
            # new QuerySet; slicing the tuple also honours the step.
            if raises:
                # Keep the failure lazy: the sliced mock raises on evaluation.
                return QuerySetMock(actual_model, *return_value)
            return QuerySetMock(actual_model, *return_value[k])
        return _getitem

    def make_iterator():
        """Return a generator function yielding the configured results."""
        def _iterator(*a, **k):
            if raises:
                raise return_value[0]
            for x in return_value:
                yield x
        return _iterator

    actual_model = model
    if actual_model:
        model = mock.MagicMock(spec=actual_model())
    else:
        model = mock.MagicMock()

    m = SharedMock()
    m.__iter__.side_effect = lambda: iter(m.iterator())
    m.__getitem__.side_effect = make_getitem(m)
    m.model = model
    m.get = make_get(m, actual_model)

    # Note since this is a SharedMock, *all* auto-generated child
    # attributes will have the same side_effect ... might not make
    # sense for some like count().
    m.iterator.side_effect = make_iterator()
    return m
Something went wrong with that request. Please try again.