diff --git a/tests/test_replicaset.py b/tests/test_replicaset.py index 66e63a0..10817fa 100644 --- a/tests/test_replicaset.py +++ b/tests/test_replicaset.py @@ -15,18 +15,15 @@ from __future__ import absolute_import, division -from pymongo.errors import OperationFailure, AutoReconnect, ConfigurationError +from bson import SON +from pymongo.errors import OperationFailure, AutoReconnect, ConfigurationError, NetworkTimeout from twisted.trial import unittest -from twisted.python.compat import _PY3 from twisted.internet import defer, reactor from txmongo.connection import MongoConnection, ConnectionPool, _Connection from txmongo.protocol import QUERY_SLAVE_OK, MongoProtocol from .mongod import Mongod -if _PY3: - from twisted.python.compat import xrange - class TestReplicaSet(unittest.TestCase): @@ -36,7 +33,7 @@ class TestReplicaSet(unittest.TestCase): rsconfig = { "_id": rsname, "members": [ - {"_id": i, "host": "localhost:{0}".format(port) } + {"_id": i, "host": "localhost:{0}".format(port)} for i, port in enumerate(ports) ] } @@ -71,7 +68,7 @@ def setUp(self): ready = False n_tries = int(self.__init_timeout / self.__ping_interval) - for i in xrange(n_tries): + for i in range(n_tries): yield self.__sleep(self.__ping_interval) # My practice shows that we need to query both ismaster and replSetGetStatus @@ -83,7 +80,7 @@ def setUp(self): ismaster, replstatus = yield defer.gatherResults([ismaster_req, replstatus_req]) initialized = replstatus["ok"] - ok_states = set(["PRIMARY", "SECONDARY"]) + ok_states = {"PRIMARY", "SECONDARY"} states_ready = all(m["stateStr"] in ok_states for m in replstatus.get("members", [])) ready = initialized and ismaster["ismaster"] and states_ready @@ -92,7 +89,8 @@ def setUp(self): if not ready: yield self.tearDown() - raise Exception("ReplicaSet initialization took more than {0}s".format(self.__init_timeout)) + raise Exception("ReplicaSet initialization took more than {0}s".format( + self.__init_timeout)) yield master.disconnect() @@ -144,18 +142,55 @@ def test_AutoReconnect(self): uri = "mongodb://localhost:{0}/?w={1}".format(self.ports[0], len(self.ports)) conn = ConnectionPool(uri) - yield conn.db.coll.insert({'x': 42}, safe = True) + yield conn.db.coll.insert({'x': 42}, safe=True) yield self.__mongod[0].stop() while True: try: result = yield conn.db.coll.find_one() + self.assertEqual(result['x'], 42) + break + except AutoReconnect: + pass + + finally: + yield conn.disconnect() + self.flushLoggedErrors(AutoReconnect) + + @defer.inlineCallbacks + def test_AutoReconnect_from_primary_step_down(self): + self.patch(_Connection, 'maxDelay', 5) + uri = "mongodb://localhost:{0}/?w={1}".format(self.ports[0], len(self.ports)) + conn = ConnectionPool(uri) + + # this will force primary to step down, triggering an AutoReconnect that bubbles up + # through the connection pool to the client + command = conn.admin.command(SON([('replSetStepDown', 86400), ('force', 1)])) + self.assertFailure(command, AutoReconnect) + + yield conn.disconnect() + + @defer.inlineCallbacks + def test_NetworkTimeout_with_deadline(self): + self.patch(_Connection, 'maxDelay', 5) + + try: + uri = "mongodb://localhost:{0}/?w={1}".format(self.ports[0], len(self.ports)) + conn = ConnectionPool(uri, deadline=2, initial_delay=3) + + yield conn.db.coll.insert({'x': 42}, safe=True) + + yield self.__mongod[0].stop() + + while True: + try: + deferred_call = conn.db.coll.find_one() + self.assertFailure(deferred_call, NetworkTimeout) break except AutoReconnect: pass - self.assertEqual(result['x'], 42) finally: yield conn.disconnect() self.flushLoggedErrors(AutoReconnect) diff --git a/txmongo/collection.py b/txmongo/collection.py index 5a9795c..4e34c44 100644 --- a/txmongo/collection.py +++ b/txmongo/collection.py @@ -7,7 +7,7 @@ from bson import BSON, ObjectId from bson.code import Code from bson.son import SON -from pymongo.errors import InvalidName +from pymongo.errors import InvalidName, NetworkTimeout from pymongo.helpers import _check_write_command_response from pymongo.results import InsertOneResult, InsertManyResult, UpdateResult, \ DeleteResult @@ -19,6 +19,7 @@ Query, Getmore, Insert, Update, Delete, KillCursors, INSERT_CONTINUE_ON_ERROR from txmongo import filter as qf from twisted.internet import defer +from twisted.python import log from twisted.python.compat import unicode, comparable @@ -132,7 +133,6 @@ def options(self): del options["create"] defer.returnValue(options) - @defer.inlineCallbacks def find(self, spec=None, skip=0, limit=0, fields=None, filter=None, cursor=False, **kwargs): docs, dfr = yield self.find_with_cursor(spec=spec, skip=skip, limit=limit, @@ -148,12 +148,13 @@ def find(self, spec=None, skip=0, limit=0, fields=None, filter=None, cursor=Fals defer.returnValue(result) - def __apply_find_filter(self, spec, filter): - if filter: + @staticmethod + def __apply_find_filter(spec, c_filter): + if c_filter: if "query" not in spec: spec = {"$query": spec} - for k, v in filter.items(): + for k, v in c_filter.items(): if isinstance(v, (list, tuple)): spec['$' + k] = dict(v) else: @@ -187,19 +188,19 @@ def find_with_cursor(self, spec=None, skip=0, limit=0, fields=None, filter=None, spec = self.__apply_find_filter(spec, filter) as_class = kwargs.get("as_class", dict) - deferred_protocol = self._database.connection.getprotocol() + proto = self._database.connection.getprotocol() - def after_connection(proto): + def after_connection(protocol): flags = kwargs.get("flags", 0) query = Query(flags=flags, collection=str(self), n_to_skip=skip, n_to_return=limit, query=spec, fields=fields) - deferred_query = proto.send_QUERY(query) - deferred_query.addCallback(after_reply, proto) + deferred_query = protocol.send_QUERY(query) + deferred_query.addCallback(after_reply, protocol) return deferred_query - def after_reply(reply, proto, fetched=0): + def after_reply(reply, protocol, fetched=0): documents = reply.documents docs_count = len(documents) if limit > 0: @@ -222,21 +223,20 @@ def after_reply(reply, proto, fetched=0): to_fetch = None # close cursor if to_fetch is None: - proto.send_KILL_CURSORS(KillCursors(cursors=[reply.cursor_id])) + protocol.send_KILL_CURSORS(KillCursors(cursors=[reply.cursor_id])) return out, defer.succeed(([], None)) - next_reply = proto.send_GETMORE(Getmore( + next_reply = protocol.send_GETMORE(Getmore( collection=str(self), cursor_id=reply.cursor_id, n_to_return=to_fetch )) - next_reply.addCallback(after_reply, proto, fetched) + next_reply.addCallback(after_reply, protocol, fetched) return out, next_reply return out, defer.succeed(([], None)) - - deferred_protocol.addCallback(after_connection) - return deferred_protocol + proto.addCallback(after_connection) + return proto @defer.inlineCallbacks def find_one(self, spec=None, fields=None, **kwargs): @@ -245,7 +245,6 @@ def find_one(self, spec=None, fields=None, **kwargs): result = yield self.find(spec=spec, limit=1, fields=fields, **kwargs) defer.returnValue(result[0] if result else None) - @defer.inlineCallbacks def count(self, spec=None, fields=None): fields = self._normalize_fields_projection(fields) @@ -283,13 +282,14 @@ def filemd5(self, spec): result = yield self._database.command("filemd5", spec, root=self._collection_name) defer.returnValue(result.get("md5")) - def _get_write_concern(self, safe=None, **wc_options): from_opts = WriteConcern(**wc_options) if from_opts.document: return from_opts - if safe == True: + if safe is None: + return self.write_concern + elif safe: if self.write_concern.acknowledged: return self.write_concern else: @@ -297,11 +297,8 @@ def _get_write_concern(self, safe=None, **wc_options): # In this case safe=True must issue getLastError without args # even if connection-level write concern was unacknowledged return WriteConcern() - elif safe == False: - return WriteConcern(w=0) - - return self.write_concern + return WriteConcern(w=0) @defer.inlineCallbacks def insert(self, docs, safe=None, flags=0, **kwargs): @@ -324,7 +321,11 @@ def insert(self, docs, safe=None, flags=0, **kwargs): docs = [BSON.encode(d) for d in docs] insert = Insert(flags=flags, collection=str(self), documents=docs) - proto = yield self._database.connection.getprotocol() + try: + proto = yield self._database.connection.getprotocol() + except NetworkTimeout as e: # prevent insertion behind the back after a timeout + log.err(str(e)) + defer.returnValue(None) proto.send_INSERT(insert) @@ -366,7 +367,6 @@ def insert_many(self, documents, ordered=True): inserted_ids = yield self._insert_one_or_many(documents, ordered) defer.returnValue(InsertManyResult(inserted_ids, self.write_concern.acknowledged)) - @defer.inlineCallbacks def update(self, spec, document, upsert=False, multi=False, safe=None, flags=0, **kwargs): if not isinstance(spec, dict): @@ -385,7 +385,12 @@ def update(self, spec, document, upsert=False, multi=False, safe=None, flags=0, document = BSON.encode(document) update = Update(flags=flags, collection=str(self), selector=spec, update=document) - proto = yield self._database.connection.getprotocol() + + try: + proto = yield self._database.connection.getprotocol() + except NetworkTimeout as e: # prevent update behind the back after a timeout + log.err(str(e)) + defer.returnValue(None) proto.send_UPDATE(update) @@ -394,7 +399,6 @@ def update(self, spec, document, upsert=False, multi=False, safe=None, flags=0, ret = yield proto.get_last_error(str(self._database), **write_concern.document) defer.returnValue(ret) - @defer.inlineCallbacks def _update(self, filter, update, upsert, multi): validate_is_mapping("filter", filter) @@ -420,7 +424,6 @@ def _update(self, filter, update, upsert, multi): defer.returnValue(raw_response) - @defer.inlineCallbacks def update_one(self, filter, update, upsert=False): validate_ok_for_update(update) @@ -502,7 +505,6 @@ def delete_many(self, filter): raw_response = yield self._delete(filter, multi=True) defer.returnValue(DeleteResult(raw_response, self.write_concern.acknowledged)) - def drop(self, **kwargs): return self._database.drop_collection(self._collection_name) @@ -623,7 +625,6 @@ def find_and_modify(self, query=None, update=None, upsert=False, **kwargs): raise ValueError("Unexpected Error: %s" % (result,)) defer.returnValue(result.get("value")) - # Distinct findAndModify utility method is needed because traditional # find_and_modify() accepts `sort` kwarg as dict and passes it to # MongoDB command without conversion. But in find_one_and_* diff --git a/txmongo/connection.py b/txmongo/connection.py index 7963f64..9df1c9a 100755 --- a/txmongo/connection.py +++ b/txmongo/connection.py @@ -4,16 +4,15 @@ from __future__ import absolute_import, division -from pymongo.errors import AutoReconnect, ConfigurationError, OperationFailure +from pymongo.errors import AutoReconnect, ConfigurationError, OperationFailure, NetworkTimeout from pymongo.uri_parser import parse_uri from pymongo.read_preferences import ReadPreference from pymongo.write_concern import WriteConcern - +from time import time from twisted.internet import defer, reactor, task from twisted.internet.protocol import ReconnectingClientFactory from twisted.python import log from twisted.python.compat import StringType - from txmongo.database import Database from txmongo.protocol import MongoProtocol, Query @@ -245,6 +244,8 @@ def __init__(self, uri="mongodb://127.0.0.1:27017", pool_size=1, ssl_context_fac assert isinstance(pool_size, int) assert pool_size >= 1 + self.__deadline = kwargs.get('deadline', None) + if not uri.startswith("mongodb://"): uri = "mongodb://" + uri @@ -255,9 +256,9 @@ def __init__(self, uri="mongodb://127.0.0.1:27017", pool_size=1, ssl_context_fac wc_options = dict((k, v) for k, v in wc_options.items() if k in self.__wc_possible_options) self.__write_concern = WriteConcern(**wc_options) - initial_timeout = kwargs.get('timeout', 1.0) + retry_delay = kwargs.get('retry_delay', 1.0) self.__pool_size = pool_size - self.__pool = [_Connection(self, self.__uri, i, initial_timeout) for i in range(pool_size)] + self.__pool = [_Connection(self, self.__uri, i, retry_delay) for i in range(pool_size)] if self.__uri['database'] and self.__uri['username'] and self.__uri['password']: self.authenticate(self.__uri['database'], self.__uri['username'], @@ -266,13 +267,13 @@ def __init__(self, uri="mongodb://127.0.0.1:27017", pool_size=1, ssl_context_fac host, port = self.__uri['nodelist'][0] - connection_timeout = kwargs.get('timeout', 30) + initial_delay = kwargs.get('retry_delay', 30) for factory in self.__pool: if ssl_context_factory: factory.connector = reactor.connectSSL( - host, port, factory, ssl_context_factory, connection_timeout) + host, port, factory, ssl_context_factory, initial_delay) else: - factory.connector = reactor.connectTCP(host, port, factory, connection_timeout) + factory.connector = reactor.connectTCP(host, port, factory, initial_delay) @property def write_concern(self): @@ -323,9 +324,12 @@ def authenticate(self, database, username, password, mechanism="DEFAULT"): except defer.FirstError as e: raise e.subFailure.value - @defer.inlineCallbacks def getprotocol(self): + # Set our deadline watchdog + if self.__deadline is not None: + start = time() + # Get the next protocol available for communication in the pool. connection = self.__pool[self.__index] self.__index = (self.__index + 1) % self.__pool_size @@ -336,6 +340,11 @@ def getprotocol(self): # Wait for the connection to connection. yield connection.notifyReady() + + # Handle deadline + if self.__deadline is not None and time() - start > self.__deadline: + raise NetworkTimeout("MongoDB timeout of {0}s reached.".format(self.__deadline)) + defer.returnValue(connection.instance) @property diff --git a/txmongo/protocol.py b/txmongo/protocol.py index 0ec264c..011ba8c 100644 --- a/txmongo/protocol.py +++ b/txmongo/protocol.py @@ -20,7 +20,8 @@ from hashlib import sha1 import hmac from pymongo import auth -from pymongo.errors import AutoReconnect, ConnectionFailure, DuplicateKeyError, OperationFailure +from pymongo.errors import AutoReconnect, ConnectionFailure, DuplicateKeyError, OperationFailure, \ + NotMasterError from random import SystemRandom import struct from twisted.internet import defer, protocol, error @@ -364,7 +365,7 @@ def handle_REPLY(self, request): msg = doc.get("$err", "Unknown error") fail_conn = False if code == 13435: - err = AutoReconnect(msg) + err = NotMasterError(msg) fail_conn = True else: err = OperationFailure(msg, code)