Skip to content

Commit

Permalink
Merge pull request #105 from IlyaSkriblovsky/inlineCallbacks-everywhere
Browse files Browse the repository at this point in the history
inlineCallbacks everywhere
  • Loading branch information
psi29a committed May 26, 2015
2 parents 702f764 + 1e692ab commit bd50ab9
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 120 deletions.
6 changes: 3 additions & 3 deletions tests/test_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def make_col(base, name):
self.assertRaises(errors.InvalidName, make_col, self.db.test, "tes..t")
self.assertRaises(errors.InvalidName, make_col, self.db.test, "tes\x00t")
self.assertRaises(TypeError, self.coll.save, "test")
self.assertRaises(ValueError, self.coll.filemd5, "test")
self.assertFailure(self.coll.filemd5("test"), ValueError)
self.assertFailure(self.db.test.find(spec="test"), TypeError)
self.assertFailure(self.db.test.find(fields="test"), TypeError)
self.assertFailure(self.db.test.find(skip="test"), TypeError)
Expand Down Expand Up @@ -97,8 +97,8 @@ def test_create_index(self):
db = self.db
coll = self.coll

self.assertRaises(TypeError, coll.create_index, 5)
self.assertRaises(TypeError, coll.create_index, {"hello": 1})
self.assertFailure(coll.create_index(5), TypeError)
self.assertFailure(coll.create_index({"hello": 1}), TypeError)

yield coll.insert({'c': 1}) # make sure collection exists.

Expand Down
7 changes: 3 additions & 4 deletions tests/test_find_and_modify.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,9 @@ def test_Update(self):
self.assertEqual(res["lulz"], 456)

def test_InvalidOptions(self):
self.assertRaises(ValueError, self.coll.find_and_modify)
self.assertRaises(ValueError, self.coll.find_and_modify,
update={"$set": {'x': 42}},
remove=True)
self.assertFailure(self.coll.find_and_modify(), ValueError)
self.assertFailure(self.coll.find_and_modify(update={"$set": {'x': 42}}, remove=True),
ValueError)

@defer.inlineCallbacks
def test_Remove(self):
Expand Down
132 changes: 54 additions & 78 deletions txmongo/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,18 +71,15 @@ def _fields_list_to_dict(fields):
def _gen_index_name(keys):
return u'_'.join([u"%s_%s" % item for item in keys])

@defer.inlineCallbacks
def options(self):
def wrapper(result):
if result:
options = result.get("options", {})
if "create" in options:
del options["create"]
return options
return {}

deferred_find_one = self._database.system.namespaces.find_one({"name": str(self)})
deferred_find_one.addCallback(wrapper)
return deferred_find_one
result = yield self._database.system.namespaces.find_one({"name": str(self)})
if not result:
result = {}
options = result.get("options", {})
if "create" in options:
del options["create"]
defer.returnValue(options)


@defer.inlineCallbacks
Expand Down Expand Up @@ -200,27 +197,25 @@ def after_reply(reply, proto, fetched=0):
deferred_protocol.addCallback(after_connection)
return deferred_protocol

@defer.inlineCallbacks
def find_one(self, spec=None, fields=None, **kwargs):
if isinstance(spec, ObjectId):
spec = {"_id": spec}
deferred_find = self.find(spec=spec, limit=1, fields=fields, **kwargs)
deferred_find.addCallback(lambda r: r[0] if r else {})
return deferred_find
result = yield self.find(spec=spec, limit=1, fields=fields, **kwargs)
defer.returnValue(result[0] if result else {})

def count(self, spec=None, fields=None):
def wrapper(result):
return result["n"]

@defer.inlineCallbacks
def count(self, spec=None, fields=None):
if fields is not None:
if not fields:
fields = ["_id"]
fields = self._fields_list_to_dict(fields)

deferred_find_one = self._database.command("count", self._collection_name,
query=spec or SON(),
fields=fields)
deferred_find_one.addCallback(wrapper)
return deferred_find_one
result = yield self._database.command("count", self._collection_name,
query=spec or SON(),
fields=fields)
defer.returnValue(result["n"])

def group(self, keys, initial, reduce, condition=None, finalize=None):
body = {
Expand All @@ -241,17 +236,15 @@ def group(self, keys, initial, reduce, condition=None, finalize=None):

return self._database.command("group", body)

@defer.inlineCallbacks
def filemd5(self, spec):
def wrapper(result):
return result.get("md5")

if not isinstance(spec, ObjectId):
raise ValueError("filemd5 expected an objectid for its "
"non-keyword argument")

deferred_fine_one = self._database.command("filemd5", spec, root=self._collection_name)
deferred_fine_one.addCallback(wrapper)
return deferred_fine_one
result = yield self._database.command("filemd5", spec, root=self._collection_name)
defer.returnValue(result.get("md5"))


@defer.inlineCallbacks
def insert(self, docs, safe=True, **kwargs):
Expand Down Expand Up @@ -346,10 +339,8 @@ def remove(self, spec, safe=True, single=False, **kwargs):
def drop(self, **kwargs):
return self._database.drop_collection(self._collection_name)

@defer.inlineCallbacks
def create_index(self, sort_fields, **kwargs):
def wrapper(result, name):
return name

if not isinstance(sort_fields, qf.sort):
raise TypeError("sort_fields must be an instance of filter.sort")

Expand All @@ -375,9 +366,8 @@ def wrapper(result, name):
kwargs["bucketSize"] = kwargs.pop("bucket_size")

index.update(kwargs)
deferred_insert = self._database.system.indexes.insert(index, safe=True)
deferred_insert.addCallback(wrapper, name)
return deferred_insert
yield self._database.system.indexes.insert(index, safe=True)
defer.returnValue(name)

def ensure_index(self, sort_fields, **kwargs):
# ensure_index is an alias of create_index since we are not
Expand All @@ -399,65 +389,46 @@ def drop_index(self, index_identifier):
def drop_indexes(self):
return self.drop_index("*")

@defer.inlineCallbacks
def index_information(self):
def wrapper(raw):
info = {}
for idx in raw:
info[idx["name"]] = idx
return info

deferred_find = self._database.system.indexes.find({"ns": str(self)})
deferred_find.addCallback(wrapper)
return deferred_find
raw = yield self._database.system.indexes.find({"ns": str(self)})
info = {}
for idx in raw:
info[idx["name"]] = idx
defer.returnValue(info)

def rename(self, new_name):
to = "%s.%s" % (str(self._database), new_name)
return self._database("admin").command("renameCollection", str(self), to=to)

@defer.inlineCallbacks
def distinct(self, key, spec=None):
def wrapper(result):
return result.get("values")

params = {"key": key}
if spec:
params["query"] = spec

d = self._database.command("distinct", self._collection_name, **params)
d.addCallback(wrapper)
return d
result = yield self._database.command("distinct", self._collection_name, **params)
defer.returnValue(result.get("values"))

@defer.inlineCallbacks
def aggregate(self, pipeline, full_response=False):
def wrapper(result, full_response):
if full_response:
return result
return result.get("result")

d = self._database.command("aggregate", self._collection_name, pipeline=pipeline)
d.addCallback(wrapper, full_response)
return d
raw = yield self._database.command("aggregate", self._collection_name, pipeline=pipeline)
if full_response:
defer.returnValue(raw)
defer.returnValue(raw.get("result"))

@defer.inlineCallbacks
def map_reduce(self, map, reduce, full_response=False, **kwargs):
def wrapper(result, full_response):
if full_response:
return result
return result.get("results")

params = {"map": map, "reduce": reduce}
params.update(**kwargs)
deferred_find_one = self._database.command("mapreduce", self._collection_name, **params)
deferred_find_one.addCallback(wrapper, full_response)
return deferred_find_one
raw = yield self._database.command("mapreduce", self._collection_name, **params)
if full_response:
defer.returnValue(raw)
defer.returnValue(raw.get("results"))

@defer.inlineCallbacks
def find_and_modify(self, query=None, update=None, upsert=False, **kwargs):
no_obj_error = "No matching object found"
def wrapper(result):
if not result["ok"]:
if result["errmsg"] == no_obj_error:
return None
else:
# Should never get here because of allowable_errors
raise ValueError("Unexpected Error: %s" % (result,))
return result.get("value")

if not update and not kwargs.get("remove", None):
raise ValueError("Must either update or remove")
Expand All @@ -474,8 +445,13 @@ def wrapper(result):
if upsert:
params["upsert"] = upsert

deferred_find_one = self._database.command("findAndModify", self._collection_name,
allowable_errors=[no_obj_error],
**params)
deferred_find_one.addCallback(wrapper)
return deferred_find_one
result = yield self._database.command("findAndModify", self._collection_name,
allowable_errors=[no_obj_error],
**params)
if not result["ok"]:
if result["errmsg"] == no_obj_error:
defer.returnValue(None)
else:
# Should never get here because of allowable_errors
raise ValueError("Unexpected Error: %s" % (result,))
defer.returnValue(result.get("value"))
29 changes: 15 additions & 14 deletions txmongo/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,25 +65,24 @@ def _initializeProto(self, proto):
except Exception as e:
proto.fail(e)

@defer.inlineCallbacks
def configure(self, proto):
"""
Configures the protocol using the information gathered from the
remote Mongo instance. Such information may contain the max
BSON document size, replica set configuration, and the master
status of the instance.
"""
if proto:
query = Query(collection="admin.$cmd", query={"ismaster": 1})
df = proto.send_QUERY(query)
df.addCallback(self._configureCallback, proto)
return df
return defer.succeed(None)

def _configureCallback(self, reply, proto):
"""
Handle the reply from the "ismaster" query. The reply contains
configuration information about the peer.
"""

if not proto:
defer.returnValue(None)

query = Query(collection="admin.$cmd", query={"ismaster": 1})
reply = yield proto.send_QUERY(query)

# Handle the reply from the "ismaster" query. The reply contains
# configuration information about the peer.

# Make sure we got a result document.
if len(reply.documents) != 1:
raise OperationFailure("Invalid document length.")
Expand Down Expand Up @@ -304,17 +303,19 @@ def authenticate(self, database, username, password, mechanism="DEFAULT"):
raise e.subFailure


@defer.inlineCallbacks
def getprotocol(self):
# Get the next protocol available for communication in the pool.
connection = self.__pool[self.__index]
self.__index = (self.__index + 1) % self.__pool_size

# If the connection is already connected, just return it.
if connection.instance:
return defer.succeed(connection.instance)
defer.returnValue(connection.instance)

# Wait for the connection to connection.
return connection.notifyReady().addCallback(lambda c: c.instance)
yield connection.notifyReady()
defer.returnValue(connection.instance)

@property
def uri(self):
Expand Down
33 changes: 12 additions & 21 deletions txmongo/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,24 +53,16 @@ def command(self, command, value=1, check=True, allowable_errors=None, **kwargs)

defer.returnValue(response)

@defer.inlineCallbacks
def create_collection(self, name, options=None):
def wrapper(result, deferred, collection):
deferred.callback(collection)

deferred = defer.Deferred()
collection = Collection(self, name)

if options:
if "size" in options:
options["size"] = float(options["size"])
yield self.command("create", name, **options)

d = self.command("create", name, **options)
d.addCallback(wrapper, deferred, collection)
d.addErrback(deferred.errback)
else:
deferred.callback(collection)

return deferred
defer.returnValue(collection)

def drop_collection(self, name_or_collection):
if isinstance(name_or_collection, Collection):
Expand All @@ -82,17 +74,16 @@ def drop_collection(self, name_or_collection):

return self.command("drop", unicode(name), allowable_errors=["ns not found"])

@defer.inlineCallbacks
def collection_names(self):
def wrapper(results):
names = [r["name"] for r in results]
names = [n[len(str(self)) + 1:] for n in names
if n.startswith(str(self) + ".")]
names = [n for n in names if "$" not in n]
return names

d = self["system.namespaces"].find()
d.addCallback(wrapper)
return d
results = yield self["system.namespaces"].find()

names = [r["name"] for r in results]
names = [n[len(str(self)) + 1:] for n in names
if n.startswith(str(self) + ".")]
names = [n for n in names if "$" not in n]
defer.returnValue(names)


@defer.inlineCallbacks
def authenticate(self, name, password):
Expand Down

0 comments on commit bd50ab9

Please sign in to comment.