diff --git a/hsdbi/mongo.py b/hsdbi/mongo.py index 470f1a8..86a4de7 100644 --- a/hsdbi/mongo.py +++ b/hsdbi/mongo.py @@ -179,7 +179,10 @@ def all(self, projection=None): Returns: pymongo.cursor.Cursor with results. """ - return self._collection.find({}, projection_dict(projection)) + if projection: + return self._collection.find({}, projection_dict(projection)) + else: + return self._collection.find() def commit(self): """Does nothing for a MongoRepository.""" @@ -248,9 +251,12 @@ def get(self, expect=True, projection=None, **kwargs): NotFoundError: if the item is not found in the database, but it is expected. """ - item = next(self._collection.find(kwargs, - projection_dict(projection)), - None) + if projection: + item = next(self._collection.find(kwargs, + projection_dict(projection)), + None) + else: + item = next(self._collection.find(kwargs), None) if not item and expect: raise errors.NotFoundError(pk=kwargs, table=self._collection_name) return item @@ -268,4 +274,19 @@ def search(self, projection=None, **kwargs): Returns: pymongo.cursor.Cursor with matching results (if any). """ - return self._collection.find(kwargs, projection_dict(projection)) + if projection: + return self._collection.find(kwargs, projection_dict(projection)) + else: + return self._collection.find(kwargs) + + def update(self, doc): + """Update the doc, saving attribute states into the db. + + Args: + doc: the document to update. + """ + _id = doc['_id'] + doc.pop('_id') + self._collection.update_one( + {'_id': _id}, {'$set': doc}) + doc['_id'] = _id diff --git a/setup.py b/setup.py index 0c3f51c..0cc0c76 100644 --- a/setup.py +++ b/setup.py @@ -7,12 +7,12 @@ setup( name='hsdbi', packages=find_packages(exclude=['testing']), - version='0.1a12', + version='0.1a13', description='A simple interface for accessing databases.', author='Tim Niven', author_email='tim.niven.public@gmail.com', url='https://github.com/timniven/hsdbi', - download_url='https://github.com/timniven/hsdbi/archive/0.1a12.tar.gz', + download_url='https://github.com/timniven/hsdbi/archive/0.1a13.tar.gz', license='MIT', classifiers=[ 'Development Status :: 3 - Alpha', diff --git a/testing/tests.py b/testing/tests.py index 1578924..6cbd4f6 100644 --- a/testing/tests.py +++ b/testing/tests.py @@ -281,14 +281,10 @@ def setUp(self): db=db, collection_name='foo') self.test_cases = [('ABC', 'abc'), ('DEF', 'def'), ('GHI', 'def')] - for test_case in self.test_cases: - if self.repository.exists(_id=test_case[0]): - self.repository.delete(_id=test_case[0]) + self.repository.delete_all_records() def tearDown(self): - for test_case in self.test_cases: - if self.repository.exists(_id=test_case[0]): - self.repository.delete(_id=test_case[0]) + self.repository.delete_all_records() def _insert_one(self): self.repository.add(_id='ABC', name='abc') @@ -382,3 +378,12 @@ def test_search_with_projection(self): self.assertEqual(len(items[1]), 1) self.assertEqual(items[0]['_id'], 'DEF') self.assertEqual(items[1]['_id'], 'GHI') + + def test_update(self): + self._insert_one() + doc = self.repository.get(_id='ABC') + doc['new_attr'] = 123 + self.repository.update(doc) + doc2 = self.repository.get(_id='ABC') + self.assertEqual(doc2['new_attr'], 123) + self.assertEqual(doc2['_id'], doc['_id'])