Skip to content

Commit

Permalink
new-style wire protocol implementation for find*() methods
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyaSkriblovsky committed Apr 2, 2020
1 parent d1df147 commit b016f7e
Show file tree
Hide file tree
Showing 5 changed files with 209 additions and 58 deletions.
4 changes: 4 additions & 0 deletions tests/test_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from twisted.trial import unittest
from twisted.internet import defer

import txmongo
import txmongo.filter as qf
from pymongo.errors import OperationFailure
Expand Down Expand Up @@ -110,6 +111,9 @@ def test_Comment(self):

@defer.inlineCallbacks
def test_Snapshot(self):
    """Exercise the ``snapshot`` query filter.

    The server's ``ismaster`` reply is consulted first; when it reports
    ``maxWireVersion`` 7 or higher the test is skipped, because the
    snapshot option is only available on MongoDB <= 3.6.
    """
    server_info = yield self.db.command('ismaster')
    wire_version = server_info['maxWireVersion']
    if wire_version >= 7:
        raise unittest.SkipTest('snapshot option is only for MongoDB <=3.6')

    yield self.__test_simple_filter(qf.snapshot(), "snapshot", True)

@defer.inlineCallbacks
Expand Down
6 changes: 4 additions & 2 deletions tests/test_replicaset.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def test_find_with_timeout(self):

yield conn.db.coll.insert({'x': 42}, safe=True)

yield self.__mongod[0].stop()
yield self.__mongod[0].kill(signal.SIGSTOP)

while True:
try:
Expand All @@ -200,6 +200,7 @@ def test_find_with_timeout(self):
pass

finally:
yield self.__mongod[0].kill(signal.SIGCONT)
yield conn.disconnect()
self.flushLoggedErrors(AutoReconnect)

Expand All @@ -213,7 +214,7 @@ def test_find_with_deadline(self):

yield conn.db.coll.insert({'x': 42}, safe=True)

yield self.__mongod[0].stop()
yield self.__mongod[0].kill(signal.SIGSTOP)

while True:
try:
Expand All @@ -225,6 +226,7 @@ def test_find_with_deadline(self):
pass

finally:
yield self.__mongod[0].kill(signal.SIGCONT)
yield conn.disconnect()
self.flushLoggedErrors(AutoReconnect)

Expand Down
211 changes: 172 additions & 39 deletions txmongo/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,54 +436,137 @@ def query():
new_kwargs = self._find_args_compat(*args, **kwargs)
return self.__real_find_with_cursor(**new_kwargs)

def __real_find_with_cursor(self, filter=None, projection=None, skip=0, limit=0, sort=None, batch_size=0,**kwargs):

if filter is None:
filter = SON()

if not isinstance(filter, dict):
raise TypeError("TxMongo: filter must be an instance of dict.")
if not isinstance(projection, (dict, list)) and projection is not None:
raise TypeError("TxMongo: projection must be an instance of dict or list.")
if not isinstance(skip, int):
raise TypeError("TxMongo: skip must be an instance of int.")
if not isinstance(limit, int):
raise TypeError("TxMongo: limit must be an instance of int.")
if not isinstance(batch_size, int):
raise TypeError("TxMongo: batch_size must be an instance of int.")

projection = self._normalize_fields_projection(projection)

filter = self.__apply_find_filter(filter, sort)
# Maps "$"-prefixed query modifier keys to the option names they are renamed
# to when _gen_find_command copies them into a `find` command document.
_MODIFIERS = SON([
('$query', 'filter'),
('$orderby', 'sort'),
('$hint', 'hint'),
('$comment', 'comment'),
('$maxScan', 'maxScan'),
('$maxTimeMS', 'maxTimeMS'),
('$max', 'max'),
('$min', 'min'),
('$returnKey', 'returnKey'),
('$showRecordId', 'showRecordId'),
('$showDiskLoc', 'showRecordId'), # <= MongoDB 3.0
('$snapshot', 'snapshot'), # MongoDB <= 3.6 only; _gen_find_command drops it when max_wire_version >= 7
])

@classmethod
def _gen_find_command(cls, coll, filter_with_modifiers, projection, skip, limit, batch_size, max_wire_version):
    """Build the ``find`` command document for a query.

    :param coll: collection name, stored as the value of the ``find`` key.
    :param filter_with_modifiers: either a plain filter document, or a
        document holding ``$query`` plus other "$"-prefixed modifiers.
    :param projection: stored under ``projection`` when truthy.
    :param skip: stored under ``skip`` when truthy.
    :param limit: ``abs(limit)`` is stored under ``limit`` when non-zero;
        a negative value additionally requests a single batch.
    :param batch_size: stored under ``batchSize`` when truthy, overriding
        the value derived from a negative ``limit``.
    :param max_wire_version: wire version of the target server; at 7 or
        above the ``snapshot`` option is removed from the command.
    :returns: a ``SON`` command document, wrapped as ``{'explain': cmd}``
        when ``$explain`` is present in ``filter_with_modifiers``.
    """
    command = SON([("find", coll)])

    if "$query" not in filter_with_modifiers:
        # No modifiers: the whole document is the filter.
        command["filter"] = filter_with_modifiers
    else:
        # Rename known modifiers to their command option names and copy
        # every other key through unchanged.
        for modifier, value in filter_with_modifiers.items():
            command[cls._MODIFIERS.get(modifier, modifier)] = value
        if max_wire_version >= 7:  # MongoDB 4.0+
            command.pop('snapshot', None)

    if projection:
        command["projection"] = projection
    if skip:
        command["skip"] = skip
    if limit:
        magnitude = abs(limit)
        command["limit"] = magnitude
        if limit < 0:
            # A negative limit asks for a single batch of that many documents.
            command["singleBatch"] = True
            command["batchSize"] = magnitude
    if batch_size:
        command["batchSize"] = batch_size

    if '$explain' not in filter_with_modifiers:
        return command

    # NOTE(review): pop() has no default here, so a filter carrying
    # '$explain' without '$query' raises KeyError -- confirm that callers
    # always wrap such filters in '$query'.
    command.pop('$explain')
    return SON([('explain', command)])

def __send_find_command(self, protocol, filter, projection, skip, limit, batch_size, as_class, flags, deadline):
codec_options = self.codec_options
if as_class is not None:
codec_options = codec_options._replace(document_class=as_class)

def after_reply(result, this_func, fetched=0):
try:
check_deadline(deadline)
except Exception:
cursor_id = result.get("cursor", {}).get("id")
if cursor_id:
kill = SON([
("killCursors", self.name),
("cursors", [cursor_id])
])
self.database.command(kill)
raise

if "cursor" not in result:
return [result], defer.succeed(([], None))
cursor = result["cursor"]

docs_key = "firstBatch"
if "nextBatch" in cursor:
docs_key = "nextBatch"

docs_count = len(cursor[docs_key])
if limit > 0:
docs_count = min(docs_count, limit - fetched)
fetched += docs_count
out = cursor[docs_key][:docs_count]

as_class = kwargs.get("as_class")
proto = self._database.connection.getprotocol()
if cursor["id"]:
if limit == 0:
to_fetch = 0 # no limit
if batch_size:
to_fetch = batch_size
elif limit < 0:
# We won't actually get here because MongoDB won't
# create a cursor when limit < 0
to_fetch = None
else:
to_fetch = limit - fetched
if to_fetch <= 0:
to_fetch = None # close cursor
elif batch_size:
to_fetch = min(batch_size, to_fetch)

def after_connection(protocol):
flags = kwargs.get("flags", 0)
if to_fetch is None:
# FIXME: extract this to a function
kill = SON([
("killCursors", self.name),
("cursors", [cursor["id"]])
])
self.database.command(kill)
return out, defer.succeed(([], None))

check_deadline(kwargs.pop("_deadline", None))
# FIXME: extract this to a function
get_more = SON([
("getMore", cursor["id"]),
("collection", self.name),
])
if batch_size:
get_more["batchSize"] = batch_size
next_reply = self.database._send_command_to_proto(protocol, get_more, codec_options=codec_options, flags=flags)
next_reply.addCallback(this_func, this_func, fetched)
return out, next_reply

if batch_size and limit:
n_to_return = min(batch_size,limit)
elif batch_size:
n_to_return = batch_size
else:
n_to_return = limit
return out, defer.succeed(([], None))

query = Query(flags=flags, collection=str(self),
n_to_skip=skip, n_to_return=n_to_return,
query=filter, fields=projection)
cmd = self._gen_find_command(self.name, filter, projection, skip, limit, batch_size, protocol.max_wire_version)
return self.database._send_command_to_proto(protocol, cmd, codec_options=codec_options, flags=flags)\
.addCallback(after_reply, after_reply)

deferred_query = protocol.send_QUERY(query)
deferred_query.addCallback(after_reply, protocol, after_reply)
return deferred_query

def __send_legacy_find(self, protocol, filter, projection, skip, limit, batch_size, as_class, deadline, kwargs):
# The this_func argument is just a reference to the after_reply function itself.
# after_reply could refer to itself directly, but that would create a circular
# reference between the closure and the function object, which would add
# unnecessary work for the GC.
def after_reply(reply, protocol, this_func, fetched=0):
try:
check_deadline(deadline)
except Exception:
if reply.cursor_id:
protocol.send_KILL_CURSORS(KillCursors(cursors=[reply.cursor_id]))
raise

documents = reply.documents
docs_count = len(documents)
Expand All @@ -500,21 +583,21 @@ def after_reply(reply, protocol, this_func, fetched=0):
if reply.cursor_id:
# please note that this will not be the case if batch_size = 1
# it is documented (parameter numberToReturn for OP_QUERY)
# https://docs.mongodb.com/manual/reference/mongodb-wire-protocol/#wire-op-query
# https://docs.mongodb.com/manual/reference/mongodb-wire-protocol/#wire-op-query
if limit == 0:
to_fetch = 0 # no limit
if batch_size:
to_fetch = batch_size
elif limit < 0:
# We won't actually get here because MongoDB won't
# create cursor when limit < 0
# create a cursor when limit < 0
to_fetch = None
else:
to_fetch = limit - fetched
if to_fetch <= 0:
to_fetch = None # close cursor
elif batch_size:
to_fetch = min(batch_size,to_fetch)
to_fetch = min(batch_size, to_fetch)

if to_fetch is None:
protocol.send_KILL_CURSORS(KillCursors(cursors=[reply.cursor_id]))
Expand All @@ -529,6 +612,56 @@ def after_reply(reply, protocol, this_func, fetched=0):

return out, defer.succeed(([], None))

flags = kwargs.get("flags", 0)

if batch_size and limit:
n_to_return = min(batch_size, limit)
elif batch_size:
n_to_return = batch_size
else:
n_to_return = limit

query = Query(flags=flags, collection=str(self),
n_to_skip=skip, n_to_return=n_to_return,
query=filter, fields=projection)

deferred_query = protocol.send_QUERY(query)
deferred_query.addCallback(after_reply, protocol, after_reply)
return deferred_query


def __real_find_with_cursor(self, filter=None, projection=None, skip=0, limit=0, sort=None, batch_size=0, **kwargs):
    """Validate find arguments and dispatch the query on a connected protocol.

    :param filter: query document (``dict``); ``None`` means an empty ``SON``.
    :param projection: ``dict``, ``list`` or ``None``; normalized through
        ``_normalize_fields_projection``.
    :param skip: ``int``, forwarded to the selected find implementation.
    :param limit: ``int``, forwarded to the selected find implementation.
    :param sort: passed to ``__apply_find_filter`` together with ``filter``.
    :param batch_size: ``int``, forwarded to the selected find implementation.
    :param kwargs: ``as_class`` and ``flags`` are read from here and
        ``_deadline`` is popped from it.
    :raises TypeError: when an argument has an unsupported type.
    :returns: the deferred returned by ``getprotocol()`` with the find
        dispatch chained on as a callback.
    """
    if filter is None:
        filter = SON()

    # Reject badly typed arguments before touching the connection.
    if not isinstance(filter, dict):
        raise TypeError("TxMongo: filter must be an instance of dict.")
    if not (projection is None or isinstance(projection, (dict, list))):
        raise TypeError("TxMongo: projection must be an instance of dict or list.")
    if not isinstance(skip, int):
        raise TypeError("TxMongo: skip must be an instance of int.")
    if not isinstance(limit, int):
        raise TypeError("TxMongo: limit must be an instance of int.")
    if not isinstance(batch_size, int):
        raise TypeError("TxMongo: batch_size must be an instance of int.")

    projection = self._normalize_fields_projection(projection)
    filter = self.__apply_find_filter(filter, sort)

    as_class = kwargs.get("as_class")
    connecting = self._database.connection.getprotocol()
    deadline = kwargs.pop("_deadline", None)

    def dispatch_find(protocol):
        check_deadline(deadline)
        if protocol.max_wire_version >= 4:
            # The server is sent the `find` command.
            return self.__send_find_command(protocol, filter, projection, skip, limit, batch_size, as_class,
                                            kwargs.get("flags", 0), deadline)
        # Older servers take the legacy find path.
        return self.__send_legacy_find(protocol, filter, projection, skip, limit, batch_size, as_class, deadline, kwargs)

    connecting.addCallback(dispatch_find)
    return connecting

Expand Down
20 changes: 10 additions & 10 deletions txmongo/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,9 @@ def _initializeProto(self, proto):
slaveok = uri_options['readpreference'] not in _PRIMARY_READ_PREFERENCES

try:
if not slaveok:
# Update our server configuration. This may disconnect if the node
# is not a master.
yield self.configure(proto)
# Update our server configuration. This may disconnect if the node
# is not a master and slaveok is not set
yield self.configure(proto, slaveok)

yield self._auth_proto(proto)
self.setInstance(instance=proto)
Expand All @@ -77,7 +76,7 @@ def __send_ismaster(proto, **kwargs):
return proto.send_QUERY(query)

@defer.inlineCallbacks
def configure(self, proto):
def configure(self, proto, slaveok):
"""
Configures the protocol using the information gathered from the
remote Mongo instance. Such information may contain the max
Expand Down Expand Up @@ -134,11 +133,12 @@ def configure(self, proto):
if host not in self.__allnodes:
self.__allnodes.append(host)

# Check if this node is the master.
ismaster = config.get("ismaster")
if not ismaster:
msg = "TxMongo: MongoDB host `%s` is not master." % config.get('me')
raise AutoReconnect(msg)
if not slaveok:
# Check if this node is the master.
ismaster = config.get("ismaster")
if not ismaster:
msg = "TxMongo: MongoDB host `%s` is not master." % config.get('me')
raise AutoReconnect(msg)

def clientConnectionFailed(self, connector, reason):
self.instance = None
Expand Down
Loading

0 comments on commit b016f7e

Please sign in to comment.