From a1f76e9c3e8e15ea161b7c9325783fddce3d7e0a Mon Sep 17 00:00:00 2001 From: Javier de la Rosa Date: Wed, 7 Nov 2012 18:30:57 -0500 Subject: [PATCH] Adding order by and filtering for relationships --- neo4jrestclient/client.py | 49 ++++++++++++++----- neo4jrestclient/constants.py | 3 ++ neo4jrestclient/query.py | 92 +++++++++++++++++++++--------------- neo4jrestclient/tests.py | 56 ++++++++++++++++++---- 4 files changed, 143 insertions(+), 57 deletions(-) diff --git a/neo4jrestclient/client.py b/neo4jrestclient/client.py index d11036a..d07d337 100644 --- a/neo4jrestclient/client.py +++ b/neo4jrestclient/client.py @@ -10,7 +10,7 @@ from lucenequerybuilder import Q import options -from query import Query, Filter, CypherException +from query import QuerySequence, FilterSequence, CypherException from constants import (BREADTH_FIRST, DEPTH_FIRST, STOP_AT_END_OF_GRAPH, NODE_GLOBAL, NODE_PATH, NODE_RECENT, @@ -191,8 +191,8 @@ def query(self, q, params=None, returns=RAW): "path": Path, "position": Position, } - return Query(self._cypher, self._auth, q=q, params=params, - types=types, returns=returns) + return QuerySequence(self._cypher, self._auth, q=q, params=params, + types=types, returns=returns) else: raise CypherException @@ -975,17 +975,17 @@ def delete(self, key, tx=None): node = self.__getitem__(key) del node - def filter(self, lookups=[], start=None, skip=None, limit=None): + def filter(self, lookups=[], start=None): if self._cypher: if start: starts = [] if not isinstance(start, (list, tuple)): - starts = [start] - for start_element in starts: - if isinstance(start_element, Node): - starts.append(start_element.id) + start = [start] + for start_element in start: + if hasattr(start_element, "id"): + starts.append(unicode(start_element.id)) else: - starts.append(start_element) + starts.append(unicode(start_element)) else: starts = u"*" start = u"node(%s)" % u", ".join(starts) @@ -997,8 +997,8 @@ def filter(self, lookups=[], start=None, skip=None, limit=None): "path": Path, "position": Position, } - return Filter(self._cypher, self._auth, start=start, types=types, - lookups=lookups, returns=(Node)) + return FilterSequence(self._cypher, self._auth, start=start, + types=types, lookups=lookups, returns=Node) else: raise CypherException @@ -1658,6 +1658,33 @@ def delete(self, key, tx=None): relationship = self.__getitem__(key) del relationship + def filter(self, lookups=[], start=None): + if self._cypher: + if start: + starts = [] + if not isinstance(start, (list, tuple)): + start = [start] + for start_element in start: + if hasattr(start_element, "id"): + starts.append(unicode(start_element.id)) + else: + starts.append(unicode(start_element)) + else: + starts = u"*" + start = u"rel(%s)" % u", ".join(starts) + if not isinstance(lookups, (list, tuple)): + lookups = [lookups] + types = { + "node": Node, + "relationship": Relationship, + "path": Path, + "position": Position, + } + return FilterSequence(self._cypher, self._auth, start=start, + types=types, lookups=lookups, returns=Node) + else: + raise CypherException + def _indexes(self): if self._relationship_index: return IndexesProxy(self._relationship_index, RELATIONSHIP, diff --git a/neo4jrestclient/constants.py b/neo4jrestclient/constants.py index 26f6fa0..66952da 100644 --- a/neo4jrestclient/constants.py +++ b/neo4jrestclient/constants.py @@ -42,6 +42,9 @@ INDEX_RELATIONSHIP = "index_relationship" INDEX_EXACT = "exact" INDEX_FULLTEXT = "fulltext" +# Cypher ordering +ASC = "asc" +DESC = "desc" # Transactions TX_GET = "GET" TX_PUT = "PUT" diff --git a/neo4jrestclient/query.py b/neo4jrestclient/query.py index 247dab7..e04a2d5 100644 --- a/neo4jrestclient/query.py +++ b/neo4jrestclient/query.py @@ -27,7 +27,7 @@ class BaseQ(object): "eq", "equals", "neq", "notequals") def __init__(self, property=None, lookup=None, match=None, - nullable=None, var=u"n", **kwargs): + nullable=True, var=u"n", **kwargs): self._and = None self._or = None self._not = None @@ -137,16 +137,16 @@ def _get_lookup_and_match(self): match = u"(?i){0}".format(self.match) elif self.lookup == "gt": lookup = u">" - match = u"{0}".format(self.match) + match = self.match elif self.lookup == "gte": - lookup = u">" - match = u"{0}".format(self.match) + lookup = u">=" + match = self.match elif self.lookup == "lt": lookup = u"<" - match = u"{0}".format(self.match) + match = self.match elif self.lookup == "lte": - lookup = u"<" - match = u"{0}".format(self.match) + lookup = u"<=" + match = self.match elif self.lookup in ["in", "inrange"]: lookup = u"IN" match = u"['{0}']".format(u"', '".join([self._escape(m) @@ -234,14 +234,16 @@ class CypherException(Exception): pass -class Query(Sequence): +class QuerySequence(Sequence): def __init__(self, cypher, auth, q, params=None, types=None, returns=None): self.q = q self.params = params - self.skip = None - self.limit = None - self.returns = returns + self._skip = None + self._limit = None + self._order_by = None + self._returns = returns + self._return_single_rows = False self._auth = auth self._cypher = cypher # This way we avoid a circular reference, by passing objects like Node @@ -253,7 +255,7 @@ def _get_elements(self): response = self.get_response() try: self._elements = self.cast(elements=response["data"], - returns=self.returns) + returns=self._returns) except: self._elements = response return self._elements @@ -275,14 +277,32 @@ def __reversed__(self): return reversed(self.elements) def get_response(self): + # Preparing slicing and ordering q = self.q params = self.params - if isinstance(self.skip, int) and "_skip" not in params: + if self._order_by: + orders = [] + for o, order in enumerate(self._order_by): + order_key = "_order_by_%s" % o + if order_key not in params: + nullable = "" + if len(order) == 3: + if order[2] is True: + nullable = "!" + elif order[2] is False: + nullable = "?" + orders.append(u"n.`{%s}`%s %s" % (order_key, nullable, order[1])) + params[order_key] = order[0] + if orders: + q = u"%s order by %s" % (q, ", ".join(orders)) + # Lazy slicing + if isinstance(self._skip, int) and "_skip" not in params: q = u"%s skip {_skip} " % q - params["_skip"] = self.skip - if isinstance(self.limit, int) and "_limit" not in params: + params["_skip"] = self._skip + if isinstance(self._limit, int) and "_limit" not in params: q = u"%s limit {_limit} " % q - params["_limit"] = self.limit + params["_limit"] = self._limit + # Making the real resquest data = { "query": q, "params": params, @@ -342,11 +362,14 @@ def cast(self, elements, returns=None): casted_row.append(sub_func(element)) else: casted_row.append(func(element)) - results.append(casted_row) + if self._return_single_rows: + results.append(*casted_row) + else: + results.append(casted_row) return results -class Filter(Query): +class FilterSequence(QuerySequence): def __init__(self, cypher, auth, start=None, lookups=[], order_by=None, types=None, returns=None): @@ -366,25 +389,20 @@ def __init__(self, cypher, auth, start=None, lookups=[], q = u"%s where %s return n " % (q, where) else: q = u"%s return n " % q - if order_by: - if not isinstance(order_by, (tuple, list)): - order_by = [order_by] - orders = [] - for o, order in enumerate(order_by): - if order.startswith(u"-"): - orders.append(u"n.`{order_by_%s}` desc" % o) - else: - orders.append(u"n.`{order_by_%s}` " % o) - params["order_by_%s" % o] = order - if orders: - q = u"%s order by %s" % (q, ", ".join(orders)) - params["start"] = start - super(Filter, self).__init__(cypher=cypher, auth=auth, q=q, - params=params, types=types, - returns=returns) + super(FilterSequence, self).__init__(cypher=cypher, auth=auth, q=q, + params=params, types=types, + returns=returns) + self._return_single_rows = True def __getitem__(self, key): if isinstance(key, slice): - self.skip = key.start - self.limit = key.stop - return super(Filter, self).__getitem__(key) + self._skip = key.start + self._limit = key.stop + return super(FilterSequence, self).__getitem__(key) + + def order_by(self, property=None, type=None, nullable=True, *args): + if property is None and isinstance(args, (list, tuple)): + self._order_by = args + else: + self._order_by = [(property, type, nullable)] + return self diff --git a/neo4jrestclient/tests.py b/neo4jrestclient/tests.py index 89d562c..0f3b58f 100644 --- a/neo4jrestclient/tests.py +++ b/neo4jrestclient/tests.py @@ -1258,7 +1258,7 @@ def test_relationship_pickle(self): self.assertEqual(r, pickle.loads(p)) -class QueryAndFilterTestCase(NodesTestCase): +class QueryTestCase(PickleTestCase): def test_query_raw(self): n1 = self.gdb.nodes.create(name="John") @@ -1313,11 +1313,14 @@ def test_query_params_returns_tuple(self): self.assertEqual(rel, r) self.assertEqual(date, 1982) + +class FilterTestCase(QueryTestCase): + def test_filter_nodes(self): Q = query.Q for i in range(5): self.gdb.nodes.create(name="William %s" % i) - lookup = Q("name", istartswith="william", nullable=True) + lookup = Q("name", istartswith="william") williams = self.gdb.nodes.filter(lookup) self.assertTrue(len(williams) >= 5) @@ -1326,24 +1329,59 @@ def test_filter_nodes_complex_lookups(self): for i in range(5): self.gdb.nodes.create(name="James", surname="Smith %s" % i) lookups = ( - Q("name", exact="James", nullable=True) & - (Q("surname", startswith="Smith", nullable=True) & - ~Q("surname", endswith="1", nullable=True)) + Q("name", exact="James") & + (Q("surname", startswith="Smith") & + ~Q("surname", endswith="1")) ) williams = self.gdb.nodes.filter(lookups) self.assertTrue(len(williams) >= 5) - def test_filter_slicing(self): + def test_filter_nodes_slicing(self): Q = query.Q for i in range(5): self.gdb.nodes.create(name="William %s" % i) - lookup = Q("name", istartswith="william", nullable=True) - import ipdb; ipdb.set_trace() + lookup = Q("name", istartswith="william") williams = self.gdb.nodes.filter(lookup)[:4] self.assertTrue(len(williams) == 4) + def test_filter_nodes_ordering(self): + Q = query.Q + for i in range(5): + self.gdb.nodes.create(name="William", code=i) + lookup = Q("code", gte=2) + williams = self.gdb.nodes.filter(lookup).order_by("code", + constants.DESC) + self.assertTrue(williams[-1]["code"] > williams[0]["code"]) + + def test_filter_nodes_nullable(self): + Q = query.Q + for i in range(5): + self.gdb.nodes.create(name="William %s" % i) + lookup = Q("name", istartswith="william", nullable=False) + williams = self.gdb.nodes.filter(lookup)[:10] + self.assertTrue(len(williams) > 5) + + def test_filter_nodes_start(self): + Q = query.Q + nodes = [] + for i in range(5): + nodes.append(self.gdb.nodes.create(name="William %s" % i)) + lookup = Q("name", istartswith="william") + williams = self.gdb.nodes.filter(lookup, start=nodes) + self.assertTrue(len(williams) == 5) + + def test_filter_relationships(self): + Q = query.Q + for i in range(10): + n1 = self.gdb.nodes.create(name="William %s" % i) + n2 = self.gdb.nodes.create(name="Rose %s" % i) + n1.loves(n2, since=(1995 + i)) + lookup = Q("since", lte=2000) + old_loves = self.gdb.relationships.filter(lookup) + self.assertTrue(len(old_loves) >= 5) + -class Neo4jPythonClientTestCase(QueryAndFilterTestCase): +class Neo4jPythonClientTestCase(FilterTestCase): pass