Skip to content

Commit

Permalink
Adds count, modifies remove and find
Browse files Browse the repository at this point in the history
Adds count method to collection. Adds possibility to remove by id. Adds possibility to return in the result of find
only selected fields. Changes the names of parameters of find and remove to comply to PyMongo but keeps the old ones
for backward compatibility.
  • Loading branch information
ca77y committed Aug 17, 2012
1 parent 263609b commit 7475fbf
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 5 deletions.
41 changes: 36 additions & 5 deletions mongomock/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import operator
import warnings
import re

from sentinels import NOTHING
Expand Down Expand Up @@ -27,6 +28,10 @@ def _all_op(doc_val, search_val):
dv = _force_list(doc_val)
return all(x in dv for x in search_val)

def _print_deprecation_warning(old_param_name, new_param_name):
warnings.warn("'%s' has been deprecated to be in line with pymongo implementation, "
"a new parameter '%s' should be used instead. the old parameter will be kept for backward "
"compatibility purposes." % old_param_name, new_param_name, DeprecationWarning)

OPERATOR_MAP = {'$ne': operator.ne,
'$gt': _not_nothing_and(operator.gt),
Expand Down Expand Up @@ -111,9 +116,25 @@ def update(self, spec, document):
existing_document.clear()
existing_document.update(document)
existing_document['_id'] = document_id
def find(self, filter=None):
dataset = (document.copy() for document in self._iter_documents(filter))
def find(self, spec=None, fields=None, filter=None):
if filter is not None:
_print_deprecation_warning('filter', 'spec')
if spec is None:
spec = filter
dataset = (self._copy_only_fields(document, fields) for document in self._iter_documents(spec))
return Cursor(dataset)
def _copy_only_fields(self, doc, fields):
"""Copy only the specified fields."""
if fields is None:
return doc.copy()
doc_copy = {}
if not fields:
fields = ["_id"]
for key in fields:
if key in doc:
doc_copy[key] = doc[key]
return doc_copy

def _iter_documents(self, filter=None):
return (document for document in itervalues(self._documents) if self._filter_applies(filter, document))
def find_one(self, filter=None):
Expand Down Expand Up @@ -148,13 +169,23 @@ def _filter_applies(self, search_filter, document):

return True

def remove(self, search_filter=None):
"""Remove objects matching search_filter from the collection."""
to_delete = list(self.find(filter=search_filter))
def remove(self, spec_or_id=None, search_filter=None):
"""Remove objects matching spec_or_id from the collection."""
if search_filter is not None:
_print_deprecation_warning('search_filter', 'spec_or_id')
if spec_or_id is None:
spec_or_id = search_filter if search_filter else {}
if not isinstance(spec_or_id, dict):
spec_or_id = {'_id': spec_or_id}
to_delete = list(self.find(spec=spec_or_id))
for doc in to_delete:
doc_id = doc['_id']
del self._documents[doc_id]

def count(self):
return len(self._documents)


class Cursor(object):
def __init__(self, dataset):
super(Cursor, self).__init__()
Expand Down
27 changes: 27 additions & 0 deletions tests/test__mongomock.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,14 @@ def test__bulk_insert(self):
self.assertItemsEqual(self.collection.find(), expected_objects)
# make sure objects were not changed in-place
self.assertEquals(objects, original_objects)
def test__count(self):
actual = self.collection.count()
self.assertEqual(0, actual)
data = dict(a=1, b=2)
self.collection.insert(data)
actual = self.collection.count()
self.assertEqual(1, actual)

class DocumentTest(FakePymongoDatabaseTest):
def setUp(self):
super(DocumentTest, self).setUp()
Expand Down Expand Up @@ -258,6 +266,19 @@ def test__find_sets(self):
self._assert_find({'x':{'$all':[2,5]}}, 'x', (prime,))
self._assert_find({'x':{'$all':[7,8]}}, 'x', ())

def test__return_only_selected_fields(self):
rec = {'name':'Chucky', 'type':'doll', 'model':'v6'}
self.collection.insert(rec)
result = list(self.collection.find({'name':'Chucky'}, fields=['type']))
self.assertEqual('doll', result[0]['type'])

def test__default_fields_to_id_if_empty(self):
rec = {'name':'Chucky', 'type':'doll', 'model':'v6'}
rec_id = self.collection.insert(rec)
result = list(self.collection.find({'name':'Chucky'}, fields=[]))
self.assertEqual(1, len(result[0]))
self.assertEqual(rec_id, result[0]['_id'])

class RemoveTest(DocumentTest):
"""Test the remove method."""
def test__remove(self):
Expand Down Expand Up @@ -286,6 +307,12 @@ def test__remove(self):
self.collection.remove({'name': 'sam'})
docs = list(self.collection.find())
self.assertEqual(len(docs), 0)
def test__remove_by_id(self):
expected = self.collection.count()
bob = {'name': 'bob'}
bob_id = self.collection.insert(bob)
self.collection.remove(bob_id)
self.assertEqual(expected, self.collection.count())

class UpdateTest(DocumentTest):
def test__update(self):
Expand Down

0 comments on commit 7475fbf

Please sign in to comment.