Merge pull request #224 from trenton42/fix-aggregate
Fix aggregate with Mongo 3.6
psi29a committed Feb 26, 2018
2 parents 85a7ac7 + 9e4dea7 commit 2a71eda
Showing 3 changed files with 112 additions and 18 deletions.
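
For context: MongoDB 3.6 removed the legacy form of the ``aggregate`` command that
returned all results inline. The command now requires a ``cursor`` document, and the
server replies with a cursor whose batches must be drained via ``getMore``. A rough
sketch of the command documents involved (field values are illustrative):

    # Initial command: an empty cursor document lets the server pick its
    # own default batch size; {"batchSize": n} requests a specific one.
    {"aggregate": "mycol", "pipeline": [{"$match": {"oh": "hai"}}], "cursor": {}}

    # Reply: the first batch plus a cursor id. An id of 0 means the batch
    # is final; any other id means more batches remain.
    {"ok": 1.0, "cursor": {"id": 42, "firstBatch": [{"oh": "hai", "lulz": 123}]}}

    # Repeated until a reply carries cursor id 0; each subsequent reply's
    # documents arrive under "nextBatch".
    {"getMore": 42, "collection": "mycol"}
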
21 changes: 15 additions & 6 deletions docs/source/NEWS.rst
@@ -1,6 +1,15 @@
Changelog
=========

Release 18.1.0 (UNRELEASED)
---------------------------

Bugfixes
^^^^^^^^

- Fixed compatibility of `Collection.aggregate()` with PyMongo 3.6


Release 18.0.0 (2018-01-02)
---------------------------

@@ -30,8 +39,8 @@ Features
- Client authentication by X509 certificates. Use your client certificate when connecting
to MongoDB and then call ``Database.authenticate`` with certificate subject as username,
empty password and ``mechanism="MONGODB-X509"``.
- ``get_version()`` to approximate the behaviour of get_version in PyMongo. One notable exception
is the omission of searching by random (unindexed) meta-data, which should be considered a bad idea
as it may create *very* variable conditions in terms of loading and timing.
- New ``ConnectionPool.drop_database()`` method for easy and convenient destruction of all your precious data.
- ``count()`` to return the number of versions of any given file in GridFS.
@@ -41,14 +50,14 @@ API Changes

- ``find()``, ``find_one()``, ``find_with_cursor()``, ``count()`` and ``distinct()`` signatures
changed to more closely match PyMongo's counterparts. New signatures are:

- ``find(filter=None, projection=None, skip=0, limit=0, sort=None, **kwargs)``
- ``find_with_cursor(filter=None, projection=None, skip=0, limit=0, sort=None, **kwargs)``
- ``find_one(filter=None, projection=None, **kwargs)``
- ``count(filter=None, **kwargs)``
- ``distinct(key, filter=None, **kwargs)``

Old signatures are now deprecated and will be supported in this and one subsequent release.
After that only new signatures will be valid.
- ``cursor`` argument to ``find()`` is deprecated. Please use ``find_with_cursor()`` directly
if you need to iterate over results by batches. ``cursor`` will be supported in this and
@@ -78,7 +87,7 @@ Features
- ``codec_options`` properties for ``ConnectionPool``, ``Database`` and ``Collection``.
``Collection.with_options(codec_options=CodecOptions(document_class=...))`` is now preferred
over ``Collection.find(..., as_class=...)``.

Bugfixes
^^^^^^^^

70 changes: 66 additions & 4 deletions tests/test_aggregate.py
@@ -33,7 +33,8 @@ def setUp(self):
self.coll = self.conn.mydb.mycol

@defer.inlineCallbacks
def test_Aggregate(self):
def test_aggregate(self):
"""Test basic aggregation functionality"""
yield self.coll.insert([{"oh": "hai", "lulz": 123},
{"oh": "kthxbye", "lulz": 456},
{"oh": "hai", "lulz": 789}, ], safe=True)
@@ -42,7 +43,7 @@ def test_Aggregate(self):
{"$project": {"oh": 1, "lolz": "$lulz"}},
{"$group": {"_id": "$oh", "many_lolz": {"$sum": "$lolz"}}},
{"$sort": {"_id": 1}}
])
])

self.assertEqual(len(res), 2)
self.assertEqual(res[0]["_id"], "hai")
@@ -52,10 +53,71 @@ def test_Aggregate(self):

res = yield self.coll.aggregate([{"$match": {"oh": "hai"}}], full_response=True)

self.assertTrue("ok" in res)
self.assertTrue("result" in res)
self.assertIn("ok", res)
self.assertIn("result", res)
self.assertEqual(len(res["result"]), 2)

res = yield self.coll.aggregate(
[{"$match": {"oh": "hai"}}], full_response=True, initial_batch_size=1
)

self.assertIn("ok", res)
self.assertIn("result", res)
self.assertEqual(len(res["result"]), 2)

@defer.inlineCallbacks
def test_large_batch(self):
"""Test aggregation with a large number of objects"""
cnt = 10000
yield self.coll.insert([{"key": "v{}".format(i), "value": i} for i in range(cnt)])
group = {
"$group": {
"_id": "$key"
}
}

# Default initial batch size (determined by the database)
res = yield self.coll.aggregate([group])
self.assertEqual(len(res), cnt)

# Initial batch size of zero (returns quickly in case of an error)
res = yield self.coll.aggregate([group], initial_batch_size=0)
self.assertEqual(len(res), cnt)

# Small initial batch size
res = yield self.coll.aggregate([group], initial_batch_size=10)
self.assertEqual(len(res), cnt)

# Initial batch size larger than the number of records
res = yield self.coll.aggregate([group], initial_batch_size=(cnt + 10))
self.assertEqual(len(res), cnt)

@defer.inlineCallbacks
def test_large_value(self):
"""Test aggregation with large objects"""
cnt = 2
yield self.coll.insert([{"x": str(i) * 1024 * 1024} for i in range(cnt)])

group = {
"$group": {"_id": "$x"}
}

# Default initial batch size (determined by the database)
res = yield self.coll.aggregate([group])
self.assertEqual(len(res), cnt)

# Initial batch size of zero (returns quickly in case of an error)
res = yield self.coll.aggregate([group], initial_batch_size=0)
self.assertEqual(len(res), cnt)

# Small initial batch size
res = yield self.coll.aggregate([group], initial_batch_size=10)
self.assertEqual(len(res), cnt)

# Initial batch size larger than the number of records
res = yield self.coll.aggregate([group], initial_batch_size=(cnt + 10))
self.assertEqual(len(res), cnt)

@defer.inlineCallbacks
def tearDown(self):
yield self.coll.drop()
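
For reference, a minimal usage sketch of the updated API exercised by the tests
above (collection and field names are illustrative; ``initial_batch_size`` is the
new optional keyword added by this change):

    from twisted.internet import defer

    @defer.inlineCallbacks
    def sum_by_key(coll):
        # Batches are drained and concatenated by the driver, so the
        # caller still receives a single complete list of documents.
        docs = yield coll.aggregate(
            [{"$group": {"_id": "$key", "total": {"$sum": "$value"}}}],
            initial_batch_size=100,  # optional hint for the first batch only
        )
        defer.returnValue(docs)
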
39 changes: 31 additions & 8 deletions txmongo/collection.py
@@ -1120,14 +1120,37 @@ def distinct(self, key, filter=None, _deadline=None, **kwargs):
**params).addCallback(lambda result: result.get("values"))

@timeout
def aggregate(self, pipeline, full_response=False, _deadline=None):
def aggregate(self, pipeline, full_response=False, initial_batch_size=None, _deadline=None):
"""aggregate(pipeline, full_response=False)"""
def on_ok(raw):
if full_response:
return raw
return raw.get("result")
return self._database.command("aggregate", self._collection_name, pipeline = pipeline,
_deadline = _deadline).addCallback(on_ok)

def on_ok(raw, data=None):
if data is None:
data = []
if "firstBatch" in raw["cursor"]:
batch = raw["cursor"]["firstBatch"]
else:
batch = raw["cursor"].get("nextBatch", [])
data += batch
if raw["cursor"]["id"] == 0:
if full_response:
raw["result"] = data
return raw
return data
next_reply = self._database.command(
"getMore", collection=self._collection_name,
getMore=raw["cursor"]["id"]
)
return next_reply.addCallback(on_ok, data)

if initial_batch_size is None:
cursor = {}
else:
cursor = {"batchSize": initial_batch_size}

return self._database.command(
"aggregate", self._collection_name, pipeline=pipeline,
_deadline=_deadline, cursor=cursor
).addCallback(on_ok)
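
The callback above drains the server-side cursor recursively: each ``getMore``
reply re-enters ``on_ok`` with the documents accumulated so far, until the server
returns a cursor id of 0. With ``full_response=True`` the accumulated documents
are patched back in under ``"result"``, preserving the pre-3.6 response shape.
A sketch of consuming the full response (the pipeline is illustrative):

    @defer.inlineCallbacks
    def match_with_status(coll):
        raw = yield coll.aggregate([{"$match": {"oh": "hai"}}], full_response=True)
        # "result" is reassembled from firstBatch/nextBatch by the driver.
        defer.returnValue((raw["ok"], raw["result"]))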

@timeout
def map_reduce(self, map, reduce, full_response=False, **kwargs):
@@ -1361,7 +1384,7 @@ def done(_):
def on_fail(failure):
failure.trap(defer.FirstError)
failure.value.subFailure.raiseException()

if self.write_concern.acknowledged and not ordered:
return defer.gatherResults(all_responses, consumeErrors=True)\
.addErrback(on_fail)
