Skip to content

Commit

Permalink
Merge 6f47c6f into fdfec71
Browse files Browse the repository at this point in the history
  • Loading branch information
psi29a committed Sep 18, 2015
2 parents fdfec71 + 6f47c6f commit 1483b79
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 52 deletions.
57 changes: 46 additions & 11 deletions tests/test_replicaset.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,15 @@

from __future__ import absolute_import, division

from pymongo.errors import OperationFailure, AutoReconnect, ConfigurationError
from bson import SON
from pymongo.errors import OperationFailure, AutoReconnect, ConfigurationError, NetworkTimeout
from twisted.trial import unittest
from twisted.python.compat import _PY3
from twisted.internet import defer, reactor
from txmongo.connection import MongoConnection, ConnectionPool, _Connection
from txmongo.protocol import QUERY_SLAVE_OK, MongoProtocol

from .mongod import Mongod

if _PY3:
from twisted.python.compat import xrange


class TestReplicaSet(unittest.TestCase):

Expand All @@ -36,7 +33,7 @@ class TestReplicaSet(unittest.TestCase):
rsconfig = {
"_id": rsname,
"members": [
{"_id": i, "host": "localhost:{0}".format(port) }
{"_id": i, "host": "localhost:{0}".format(port)}
for i, port in enumerate(ports)
]
}
Expand Down Expand Up @@ -71,7 +68,7 @@ def setUp(self):

ready = False
n_tries = int(self.__init_timeout / self.__ping_interval)
for i in xrange(n_tries):
for i in range(n_tries):
yield self.__sleep(self.__ping_interval)

# My practice shows that we need to query both ismaster and replSetGetStatus
Expand All @@ -83,7 +80,7 @@ def setUp(self):
ismaster, replstatus = yield defer.gatherResults([ismaster_req, replstatus_req])

initialized = replstatus["ok"]
ok_states = set(["PRIMARY", "SECONDARY"])
ok_states = {"PRIMARY", "SECONDARY"}
states_ready = all(m["stateStr"] in ok_states for m in replstatus.get("members", []))
ready = initialized and ismaster["ismaster"] and states_ready

Expand All @@ -92,7 +89,8 @@ def setUp(self):

if not ready:
yield self.tearDown()
raise Exception("ReplicaSet initialization took more than {0}s".format(self.__init_timeout))
raise Exception("ReplicaSet initialization took more than {0}s".format(
self.__init_timeout))

yield master.disconnect()

Expand Down Expand Up @@ -144,18 +142,55 @@ def test_AutoReconnect(self):
uri = "mongodb://localhost:{0}/?w={1}".format(self.ports[0], len(self.ports))
conn = ConnectionPool(uri)

yield conn.db.coll.insert({'x': 42}, safe = True)
yield conn.db.coll.insert({'x': 42}, safe=True)

yield self.__mongod[0].stop()

while True:
try:
result = yield conn.db.coll.find_one()
self.assertEqual(result['x'], 42)
break
except AutoReconnect:
pass

finally:
yield conn.disconnect()
self.flushLoggedErrors(AutoReconnect)

@defer.inlineCallbacks
def test_AutoReconnect_from_primary_step_down(self):
self.patch(_Connection, 'maxDelay', 5)
uri = "mongodb://localhost:{0}/?w={1}".format(self.ports[0], len(self.ports))
conn = ConnectionPool(uri)

# this will force primary to step down, triggering an AutoReconnect that bubbles up
# through the connection pool to the client
command = conn.admin.command(SON([('replSetStepDown', 86400), ('force', 1)]))
self.assertFailure(command, AutoReconnect)

yield conn.disconnect()

@defer.inlineCallbacks
def test_NetworkTimeout_with_deadline(self):
self.patch(_Connection, 'maxDelay', 5)

try:
uri = "mongodb://localhost:{0}/?w={1}".format(self.ports[0], len(self.ports))
conn = ConnectionPool(uri, deadline=2, initial_delay=3)

yield conn.db.coll.insert({'x': 42}, safe=True)

yield self.__mongod[0].stop()

while True:
try:
deferred_call = conn.db.coll.find_one()
self.assertFailure(deferred_call, NetworkTimeout)
break
except AutoReconnect:
pass

self.assertEqual(result['x'], 42)
finally:
yield conn.disconnect()
self.flushLoggedErrors(AutoReconnect)
Expand Down
61 changes: 31 additions & 30 deletions txmongo/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from bson import BSON, ObjectId
from bson.code import Code
from bson.son import SON
from pymongo.errors import InvalidName
from pymongo.errors import InvalidName, NetworkTimeout
from pymongo.helpers import _check_write_command_response
from pymongo.results import InsertOneResult, InsertManyResult, UpdateResult, \
DeleteResult
Expand All @@ -19,6 +19,7 @@
Query, Getmore, Insert, Update, Delete, KillCursors, INSERT_CONTINUE_ON_ERROR
from txmongo import filter as qf
from twisted.internet import defer
from twisted.python import log
from twisted.python.compat import unicode, comparable


Expand Down Expand Up @@ -132,7 +133,6 @@ def options(self):
del options["create"]
defer.returnValue(options)


@defer.inlineCallbacks
def find(self, spec=None, skip=0, limit=0, fields=None, filter=None, cursor=False, **kwargs):
docs, dfr = yield self.find_with_cursor(spec=spec, skip=skip, limit=limit,
Expand All @@ -148,12 +148,13 @@ def find(self, spec=None, skip=0, limit=0, fields=None, filter=None, cursor=Fals

defer.returnValue(result)

def __apply_find_filter(self, spec, filter):
if filter:
@staticmethod
def __apply_find_filter(spec, c_filter):
if c_filter:
if "query" not in spec:
spec = {"$query": spec}

for k, v in filter.items():
for k, v in c_filter.items():
if isinstance(v, (list, tuple)):
spec['$' + k] = dict(v)
else:
Expand Down Expand Up @@ -187,19 +188,19 @@ def find_with_cursor(self, spec=None, skip=0, limit=0, fields=None, filter=None,
spec = self.__apply_find_filter(spec, filter)

as_class = kwargs.get("as_class", dict)
deferred_protocol = self._database.connection.getprotocol()
proto = self._database.connection.getprotocol()

def after_connection(proto):
def after_connection(protocol):
flags = kwargs.get("flags", 0)
query = Query(flags=flags, collection=str(self),
n_to_skip=skip, n_to_return=limit,
query=spec, fields=fields)

deferred_query = proto.send_QUERY(query)
deferred_query.addCallback(after_reply, proto)
deferred_query = protocol.send_QUERY(query)
deferred_query.addCallback(after_reply, protocol)
return deferred_query

def after_reply(reply, proto, fetched=0):
def after_reply(reply, protocol, fetched=0):
documents = reply.documents
docs_count = len(documents)
if limit > 0:
Expand All @@ -222,21 +223,20 @@ def after_reply(reply, proto, fetched=0):
to_fetch = None # close cursor

if to_fetch is None:
proto.send_KILL_CURSORS(KillCursors(cursors=[reply.cursor_id]))
protocol.send_KILL_CURSORS(KillCursors(cursors=[reply.cursor_id]))
return out, defer.succeed(([], None))

next_reply = proto.send_GETMORE(Getmore(
next_reply = protocol.send_GETMORE(Getmore(
collection=str(self), cursor_id=reply.cursor_id,
n_to_return=to_fetch
))
next_reply.addCallback(after_reply, proto, fetched)
next_reply.addCallback(after_reply, protocol, fetched)
return out, next_reply

return out, defer.succeed(([], None))


deferred_protocol.addCallback(after_connection)
return deferred_protocol
proto.addCallback(after_connection)
return proto

@defer.inlineCallbacks
def find_one(self, spec=None, fields=None, **kwargs):
Expand All @@ -245,7 +245,6 @@ def find_one(self, spec=None, fields=None, **kwargs):
result = yield self.find(spec=spec, limit=1, fields=fields, **kwargs)
defer.returnValue(result[0] if result else None)


@defer.inlineCallbacks
def count(self, spec=None, fields=None):
fields = self._normalize_fields_projection(fields)
Expand Down Expand Up @@ -283,25 +282,23 @@ def filemd5(self, spec):
result = yield self._database.command("filemd5", spec, root=self._collection_name)
defer.returnValue(result.get("md5"))


def _get_write_concern(self, safe=None, **wc_options):
from_opts = WriteConcern(**wc_options)
if from_opts.document:
return from_opts

if safe == True:
if safe is None:
return self.write_concern
elif safe:
if self.write_concern.acknowledged:
return self.write_concern
else:
# Edge case: MongoConnection(w=0).db.coll.insert(..., safe=True)
# In this case safe=True must issue getLastError without args
# even if connection-level write concern was unacknowledged
return WriteConcern()
elif safe == False:
return WriteConcern(w=0)

return self.write_concern

return WriteConcern(w=0)

@defer.inlineCallbacks
def insert(self, docs, safe=None, flags=0, **kwargs):
Expand All @@ -324,7 +321,11 @@ def insert(self, docs, safe=None, flags=0, **kwargs):
docs = [BSON.encode(d) for d in docs]
insert = Insert(flags=flags, collection=str(self), documents=docs)

proto = yield self._database.connection.getprotocol()
try:
proto = yield self._database.connection.getprotocol()
except NetworkTimeout as e: # prevent insertion behind the back after a timeout
log.err(str(e))
defer.returnValue(None)

proto.send_INSERT(insert)

Expand Down Expand Up @@ -366,7 +367,6 @@ def insert_many(self, documents, ordered=True):
inserted_ids = yield self._insert_one_or_many(documents, ordered)
defer.returnValue(InsertManyResult(inserted_ids, self.write_concern.acknowledged))


@defer.inlineCallbacks
def update(self, spec, document, upsert=False, multi=False, safe=None, flags=0, **kwargs):
if not isinstance(spec, dict):
Expand All @@ -385,7 +385,12 @@ def update(self, spec, document, upsert=False, multi=False, safe=None, flags=0,
document = BSON.encode(document)
update = Update(flags=flags, collection=str(self),
selector=spec, update=document)
proto = yield self._database.connection.getprotocol()

try:
proto = yield self._database.connection.getprotocol()
except NetworkTimeout as e: # prevent update behind the back after a timeout
log.err(str(e))
defer.returnValue(None)

proto.send_UPDATE(update)

Expand All @@ -394,7 +399,6 @@ def update(self, spec, document, upsert=False, multi=False, safe=None, flags=0,
ret = yield proto.get_last_error(str(self._database), **write_concern.document)
defer.returnValue(ret)


@defer.inlineCallbacks
def _update(self, filter, update, upsert, multi):
validate_is_mapping("filter", filter)
Expand All @@ -420,7 +424,6 @@ def _update(self, filter, update, upsert, multi):

defer.returnValue(raw_response)


@defer.inlineCallbacks
def update_one(self, filter, update, upsert=False):
validate_ok_for_update(update)
Expand Down Expand Up @@ -502,7 +505,6 @@ def delete_many(self, filter):
raw_response = yield self._delete(filter, multi=True)
defer.returnValue(DeleteResult(raw_response, self.write_concern.acknowledged))


def drop(self, **kwargs):
return self._database.drop_collection(self._collection_name)

Expand Down Expand Up @@ -623,7 +625,6 @@ def find_and_modify(self, query=None, update=None, upsert=False, **kwargs):
raise ValueError("Unexpected Error: %s" % (result,))
defer.returnValue(result.get("value"))


# Distinct findAndModify utility method is needed because traditional
# find_and_modify() accepts `sort` kwarg as dict and passes it to
# MongoDB command without conversion. But in find_one_and_*
Expand Down
27 changes: 18 additions & 9 deletions txmongo/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,15 @@

from __future__ import absolute_import, division

from pymongo.errors import AutoReconnect, ConfigurationError, OperationFailure
from pymongo.errors import AutoReconnect, ConfigurationError, OperationFailure, NetworkTimeout
from pymongo.uri_parser import parse_uri
from pymongo.read_preferences import ReadPreference
from pymongo.write_concern import WriteConcern

from time import time
from twisted.internet import defer, reactor, task
from twisted.internet.protocol import ReconnectingClientFactory
from twisted.python import log
from twisted.python.compat import StringType

from txmongo.database import Database
from txmongo.protocol import MongoProtocol, Query

Expand Down Expand Up @@ -245,6 +244,8 @@ def __init__(self, uri="mongodb://127.0.0.1:27017", pool_size=1, ssl_context_fac
assert isinstance(pool_size, int)
assert pool_size >= 1

self.__deadline = kwargs.get('deadline', None)

if not uri.startswith("mongodb://"):
uri = "mongodb://" + uri

Expand All @@ -255,9 +256,9 @@ def __init__(self, uri="mongodb://127.0.0.1:27017", pool_size=1, ssl_context_fac
wc_options = dict((k, v) for k, v in wc_options.items() if k in self.__wc_possible_options)
self.__write_concern = WriteConcern(**wc_options)

initial_timeout = kwargs.get('timeout', 1.0)
retry_delay = kwargs.get('retry_delay', 1.0)
self.__pool_size = pool_size
self.__pool = [_Connection(self, self.__uri, i, initial_timeout) for i in range(pool_size)]
self.__pool = [_Connection(self, self.__uri, i, retry_delay) for i in range(pool_size)]

if self.__uri['database'] and self.__uri['username'] and self.__uri['password']:
self.authenticate(self.__uri['database'], self.__uri['username'],
Expand All @@ -266,13 +267,13 @@ def __init__(self, uri="mongodb://127.0.0.1:27017", pool_size=1, ssl_context_fac

host, port = self.__uri['nodelist'][0]

connection_timeout = kwargs.get('timeout', 30)
initial_delay = kwargs.get('retry_delay', 30)
for factory in self.__pool:
if ssl_context_factory:
factory.connector = reactor.connectSSL(
host, port, factory, ssl_context_factory, connection_timeout)
host, port, factory, ssl_context_factory, initial_delay)
else:
factory.connector = reactor.connectTCP(host, port, factory, connection_timeout)
factory.connector = reactor.connectTCP(host, port, factory, initial_delay)

@property
def write_concern(self):
Expand Down Expand Up @@ -323,9 +324,12 @@ def authenticate(self, database, username, password, mechanism="DEFAULT"):
except defer.FirstError as e:
raise e.subFailure.value


@defer.inlineCallbacks
def getprotocol(self):
# Set our deadline watchdog
if self.__deadline is not None:
start = time()

# Get the next protocol available for communication in the pool.
connection = self.__pool[self.__index]
self.__index = (self.__index + 1) % self.__pool_size
Expand All @@ -336,6 +340,11 @@ def getprotocol(self):

# Wait for the connection to be established.
yield connection.notifyReady()

# Handle deadline
if self.__deadline is not None and time() - start > self.__deadline:
raise NetworkTimeout("MongoDB timeout of {0}s reached.".format(self.__deadline))

defer.returnValue(connection.instance)

@property
Expand Down
Loading

0 comments on commit 1483b79

Please sign in to comment.