From aa72878a0239d8979ea987446098a090719cfd71 Mon Sep 17 00:00:00 2001 From: Ilya Skriblovsky Date: Thu, 21 May 2015 23:21:12 +0300 Subject: [PATCH] Failover fixed and covered with unittests --- tests/mongod.py | 21 +++++- tests/test_auth.py | 6 +- tests/test_replicaset.py | 144 +++++++++++++++++++++++++++++++++++++++ txmongo/connection.py | 64 +++++++++-------- txmongo/protocol.py | 14 +++- 5 files changed, 212 insertions(+), 37 deletions(-) create mode 100644 tests/test_replicaset.py diff --git a/tests/mongod.py b/tests/mongod.py index 7cfab36..6e7f368 100644 --- a/tests/mongod.py +++ b/tests/mongod.py @@ -1,3 +1,7 @@ +import os +import tempfile +import shutil + from twisted.internet import defer, reactor from twisted.internet.error import ProcessDone @@ -9,30 +13,35 @@ class Mongod(object): # so leaving this for now success_message = "waiting for connections on port" - def __init__(self, dbpath, port=27017, auth=False): + def __init__(self, port=27017, auth=False, replset = None): self.__proc = None self.__notify_waiting = [] self.__notify_stop = [] self.__output = '' self.__end_reason = None - self.dbpath = dbpath + self.__datadir = None + self.port = port self.auth = auth + self.replset = replset def start(self): + self.__datadir = tempfile.mkdtemp() + d = defer.Deferred() self.__notify_waiting.append(d) args = ["mongod", "--port", str(self.port), - "--dbpath", str(self.dbpath), + "--dbpath", self.__datadir, "--noprealloc", "--nojournal", "--smallfiles", "--nssize", "1", "--nohttpinterface", ] if self.auth: args.append("--auth") + if self.replset: args.extend(["--replSet", self.replset]) self.__proc = reactor.spawnProcess(self, "mongod", args) return d @@ -71,3 +80,9 @@ def processEnded(self, reason): d.callback(None) else: d.errback(reason) + + if self.__datadir: + shutil.rmtree(self.__datadir) + + + def output(self): return self.__output diff --git a/tests/test_auth.py b/tests/test_auth.py index d826ca3..4bdb77b 100644 --- a/tests/test_auth.py +++ 
b/tests/test_auth.py @@ -99,10 +99,7 @@ def createDBUsers(self): @defer.inlineCallbacks def setUp(self): - self.__datadir = self.mktemp() - os.makedirs(self.__datadir) - - self.__mongod = Mongod(self.__datadir, port=mongo_port, auth=True) + self.__mongod = Mongod(port=mongo_port, auth=True) yield self.__mongod.start() yield self.createUserAdmin() @@ -124,7 +121,6 @@ def tearDown(self): yield conn.disconnect() finally: yield self.__mongod.stop() - shutil.rmtree(self.__datadir) @defer.inlineCallbacks def test_AuthConnectionPool(self): diff --git a/tests/test_replicaset.py b/tests/test_replicaset.py new file mode 100644 index 0000000..d662101 --- /dev/null +++ b/tests/test_replicaset.py @@ -0,0 +1,144 @@ +# coding: utf-8 +# Copyright 2010 Mark L. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from bson.son import SON +from pymongo.errors import OperationFailure, AutoReconnect +from twisted.trial import unittest +from twisted.internet import defer, base, reactor +from txmongo.connection import MongoConnection, ConnectionPool, _Connection +from txmongo.protocol import QUERY_SLAVE_OK + +from mongod import Mongod + +# base.DelayedCall.debug = True + + +class TestReplicaSet(unittest.TestCase): + + ports = [37017, 37018, 37019] + rsname = "rs1" + + rsconfig = { + "_id": rsname, + "members": [ + {"_id": i, "host": "localhost:{0}".format(port) } + for i, port in enumerate(ports) + ] + } + # We assume first member to be master + rsconfig["members"][0]["priority"] = 2 + + def __sleep(self, delay): + d = defer.Deferred() + reactor.callLater(delay, d.callback, None) + return d + + @defer.inlineCallbacks + def setUp(self): + self.__mongod = [Mongod(port=p, replset=self.rsname) for p in self.ports] + yield defer.gatherResults([mongo.start() for mongo in self.__mongod]) + + master_uri = "mongodb://localhost:{0}/?readPreference=secondaryPreferred".format(self.ports[0]) + master = ConnectionPool(master_uri) + yield master.admin["$cmd"].find_one({"replSetInitiate": self.rsconfig}) + + ready = False + for i in xrange(30): + yield self.__sleep(0.5) + + # My practice shows that we need to query both ismaster and replSetGetStatus + # to be sure that replica set is up and running, primary is elected and all + # secondaries are in sync and ready to become new primary + + ismaster_req = master.admin["$cmd"].find_one({"ismaster": 1}) + replstatus_req = master.admin["$cmd"].find_one({"replSetGetStatus": 1}) + ismaster, replstatus = yield defer.gatherResults([ismaster_req, replstatus_req]) + + startup = any(m["stateStr"].startswith("STARTUP") for m in replstatus["members"]) + ready = ismaster["ismaster"] and not startup + + if ready: + break + + if not ready: + yield self.tearDown() + raise Exception("ReplicaSet initialization took more than 15s") + + yield master.disconnect() 
+ + + @defer.inlineCallbacks + def tearDown(self): + yield defer.gatherResults([mongo.stop() for mongo in self.__mongod]) + + + @defer.inlineCallbacks + def test_WriteToMaster(self): + conn = MongoConnection("localhost", self.ports[0]) + try: + coll = conn.db.coll + yield coll.insert({'x': 42}, safe=True) + result = yield coll.find_one() + self.assertEqual(result['x'], 42) + finally: + yield conn.disconnect() + + @defer.inlineCallbacks + def test_SlaveOk(self): + uri = "mongodb://localhost:{0}/?readPreference=secondaryPreferred".format(self.ports[1]) + conn = ConnectionPool(uri) + try: + empty = yield conn.db.coll.find(flags=QUERY_SLAVE_OK) + self.assertEqual(empty, []) + + yield self.assertFailure(conn.db.coll.insert({'x': 42}), OperationFailure) + finally: + yield conn.disconnect() + + + @defer.inlineCallbacks + def test_SwitchToMasterOnConnect(self): + # Reverse hosts order + try: + conn = MongoConnection("localhost", self.ports[1]) + result = yield conn.db.coll.find({'x': 42}) + self.assertEqual(result, []) + finally: + yield conn.disconnect() + + # txmongo will do log.err() for AutoReconnects + self.flushLoggedErrors(AutoReconnect) + + @defer.inlineCallbacks + def test_AutoReconnect(self): + self.patch(_Connection, 'maxDelay', 5) + + try: + uri = "mongodb://localhost:{0}/?w={1}".format(self.ports[0], len(self.ports)) + conn = ConnectionPool(uri) + + yield conn.db.coll.insert({'x': 42}, safe = True) + + yield self.__mongod[0].stop() + + try: + result = yield conn.db.coll.find_one() + except AutoReconnect: + result = yield conn.db.coll.find_one() + + self.assertEqual(result['x'], 42) + finally: + yield conn.disconnect() + self.flushLoggedErrors(AutoReconnect) diff --git a/txmongo/connection.py b/txmongo/connection.py index e4355b2..7076db7 100755 --- a/txmongo/connection.py +++ b/txmongo/connection.py @@ -4,6 +4,7 @@ from pymongo.errors import AutoReconnect, ConfigurationError, OperationFailure from pymongo.uri_parser import parse_uri +from 
pymongo.read_preferences import ReadPreference from twisted.internet import defer, reactor, task from twisted.internet.protocol import ReconnectingClientFactory @@ -15,7 +16,7 @@ class _Connection(ReconnectingClientFactory): __notify_ready = None - __discovered = None + __allnodes = None __index = -1 __uri = None __conf_loop = None @@ -26,7 +27,7 @@ class _Connection(ReconnectingClientFactory): maxDelay = 60 def __init__(self, pool, uri, id): - self.__discovered = [] + self.__allnodes = list(uri["nodelist"]) self.__notify_ready = [] self.__pool = pool self.__uri = uri @@ -39,18 +40,30 @@ def __init__(self, pool, uri, id): def buildProtocol(self, addr): # Build the protocol. p = ReconnectingClientFactory.buildProtocol(self, addr) + self._initializeProto(p) + return p - ready_deferred = p.connectionReady() + @defer.inlineCallbacks + def _initializeProto(self, proto): + yield proto.connectionReady() + self.resetDelay() - if not self.uri['options'].get('slaveok', False): - # Update our server configuration. This may disconnect if the node - # is not a master. - ready_deferred.addCallback(lambda _: self.configure(p)) + uri_options = self.uri['options'] + slaveok = uri_options.get('slaveok', False) + if 'readpreference' in uri_options: + slaveok = uri_options['readpreference'] not in (ReadPreference.PRIMARY.mode, + ReadPreference.PRIMARY_PREFERRED.mode) - ready_deferred\ - .addCallback(lambda _: self._auth_proto(p))\ - .addBoth(lambda _: self.setInstance(instance=p)) - return p + try: + if not slaveok: + # Update our server configuration. This may disconnect if the node + # is not a master. + yield self.configure(proto) + + yield self._auth_proto(proto) + self.setInstance(instance=proto) + except Exception as e: + proto.fail(e) def configure(self, proto): """ @@ -73,8 +86,7 @@ def _configureCallback(self, reply, proto): """ # Make sure we got a result document. 
if len(reply.documents) != 1: - proto.fail(OperationFailure("Invalid document length.")) - return + raise OperationFailure("Invalid document length.") # Get the configuration document from the reply. config = reply.documents[0].decode() @@ -83,8 +95,7 @@ def _configureCallback(self, reply, proto): if not config.get("ok"): code = config.get("code") msg = config.get("err", "Unknown error") - proto.fail(OperationFailure(msg, code)) - return + raise OperationFailure(msg, code) # Check that the replicaSet matches. set_name = config.get("setName") @@ -92,9 +103,7 @@ def _configureCallback(self, reply, proto): if expected_set_name and (expected_set_name != set_name): # Log the invalid replica set failure. msg = "Mongo instance does not match requested replicaSet." - reason = ConfigurationError(msg) - proto.fail(reason) - return + raise ConfigurationError(msg) # Track max bson object size limit. max_bson_size = config.get("maxBsonObjectSize") @@ -107,22 +116,20 @@ def _configureCallback(self, reply, proto): # Track the other hosts in the replica set. hosts = config.get("hosts") if isinstance(hosts, list) and hosts: - hostaddrs = [] for host in hosts: if ':' not in host: host = (host, 27017) else: host = host.split(':', 1) host[1] = int(host[1]) - hostaddrs.append(host) - self.__discovered = hostaddrs + host = tuple(host) + if host not in self.__allnodes: + self.__allnodes.append(host) # Check if this node is the master. 
ismaster = config.get("ismaster") if not ismaster: - reason = AutoReconnect("not master") - proto.fail(reason) - return + raise AutoReconnect("not master") def clientConnectionFailed(self, connector, reason): self.instance = None @@ -171,12 +178,11 @@ def retryNextHost(self, connector=None): delay = False self.__index += 1 - all_nodes = list(self.uri["nodelist"]) + list(self.__discovered) - if self.__index >= len(all_nodes): + if self.__index >= len(self.__allnodes): self.__index = 0 delay = True - connector.host, connector.port = all_nodes[self.__index] + connector.host, connector.port = self.__allnodes[self.__index] if delay: self.retry(connector) @@ -184,6 +190,10 @@ def retryNextHost(self, connector=None): connector.connect() def setInstance(self, instance=None, reason=None): + if instance == self.instance: + # Should not fail deferreds from __notify_ready if setInstance(None) + # called when instance is already None + return self.instance = instance deferreds, self.__notify_ready = self.__notify_ready, [] if deferreds: diff --git a/txmongo/protocol.py b/txmongo/protocol.py index e194a3d..d2de79d 100644 --- a/txmongo/protocol.py +++ b/txmongo/protocol.py @@ -309,14 +309,24 @@ def connectionMade(self): df.callback(self) def connectionLost(self, reason=connectionDone): + # We need to clear factory.instance before failing deferreds + # because client code might immediately re-issue query when + # it catches AutoReconnect, so we must invalidate current + # connection before. Factory.clientConnectionFailed() is called + # too late. 
+ self.factory.setInstance(None, reason) + + autoreconnect = AutoReconnect() + if self.__deferreds: deferreds, self.__deferreds = self.__deferreds, {} for df in deferreds.itervalues(): - df.errback(reason) + df.errback(autoreconnect) deferreds, self.__connection_ready = self.__connection_ready, [] if deferreds: for df in deferreds: - df.errback(reason) + df.errback(autoreconnect) + protocol.Protocol.connectionLost(self, reason) def connectionReady(self):