Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
Adding order by and filtering for relationships
  • Loading branch information
versae committed Nov 7, 2012
1 parent b2c5c38 commit a1f76e9
Show file tree
Hide file tree
Showing 4 changed files with 143 additions and 57 deletions.
49 changes: 38 additions & 11 deletions neo4jrestclient/client.py
Expand Up @@ -10,7 +10,7 @@
from lucenequerybuilder import Q from lucenequerybuilder import Q


import options import options
from query import Query, Filter, CypherException from query import QuerySequence, FilterSequence, CypherException
from constants import (BREADTH_FIRST, DEPTH_FIRST, from constants import (BREADTH_FIRST, DEPTH_FIRST,
STOP_AT_END_OF_GRAPH, STOP_AT_END_OF_GRAPH,
NODE_GLOBAL, NODE_PATH, NODE_RECENT, NODE_GLOBAL, NODE_PATH, NODE_RECENT,
Expand Down Expand Up @@ -191,8 +191,8 @@ def query(self, q, params=None, returns=RAW):
"path": Path, "path": Path,
"position": Position, "position": Position,
} }
return Query(self._cypher, self._auth, q=q, params=params, return QuerySequence(self._cypher, self._auth, q=q, params=params,
types=types, returns=returns) types=types, returns=returns)
else: else:
raise CypherException raise CypherException


Expand Down Expand Up @@ -975,17 +975,17 @@ def delete(self, key, tx=None):
node = self.__getitem__(key) node = self.__getitem__(key)
del node del node


def filter(self, lookups=[], start=None, skip=None, limit=None): def filter(self, lookups=[], start=None):
if self._cypher: if self._cypher:
if start: if start:
starts = [] starts = []
if not isinstance(start, (list, tuple)): if not isinstance(start, (list, tuple)):
starts = [start] start = [start]
for start_element in starts: for start_element in start:
if isinstance(start_element, Node): if hasattr(start_element, "id"):
starts.append(start_element.id) starts.append(unicode(start_element.id))
else: else:
starts.append(start_element) starts.append(unicode(start_element))
else: else:
starts = u"*" starts = u"*"
start = u"node(%s)" % u", ".join(starts) start = u"node(%s)" % u", ".join(starts)
Expand All @@ -997,8 +997,8 @@ def filter(self, lookups=[], start=None, skip=None, limit=None):
"path": Path, "path": Path,
"position": Position, "position": Position,
} }
return Filter(self._cypher, self._auth, start=start, types=types, return FilterSequence(self._cypher, self._auth, start=start,
lookups=lookups, returns=(Node)) types=types, lookups=lookups, returns=Node)
else: else:
raise CypherException raise CypherException


Expand Down Expand Up @@ -1658,6 +1658,33 @@ def delete(self, key, tx=None):
relationship = self.__getitem__(key) relationship = self.__getitem__(key)
del relationship del relationship


def filter(self, lookups=[], start=None):
if self._cypher:
if start:
starts = []
if not isinstance(start, (list, tuple)):
start = [start]
for start_element in start:
if hasattr(start_element, "id"):
starts.append(unicode(start_element.id))
else:
starts.append(unicode(start_element))
else:
starts = u"*"
start = u"rel(%s)" % u", ".join(starts)
if not isinstance(lookups, (list, tuple)):
lookups = [lookups]
types = {
"node": Node,
"relationship": Relationship,
"path": Path,
"position": Position,
}
return FilterSequence(self._cypher, self._auth, start=start,
types=types, lookups=lookups, returns=Node)
else:
raise CypherException

def _indexes(self): def _indexes(self):
if self._relationship_index: if self._relationship_index:
return IndexesProxy(self._relationship_index, RELATIONSHIP, return IndexesProxy(self._relationship_index, RELATIONSHIP,
Expand Down
3 changes: 3 additions & 0 deletions neo4jrestclient/constants.py
Expand Up @@ -42,6 +42,9 @@
INDEX_RELATIONSHIP = "index_relationship" INDEX_RELATIONSHIP = "index_relationship"
INDEX_EXACT = "exact" INDEX_EXACT = "exact"
INDEX_FULLTEXT = "fulltext" INDEX_FULLTEXT = "fulltext"
# Cypher ordering
ASC = "asc"
DESC = "desc"
# Transactions # Transactions
TX_GET = "GET" TX_GET = "GET"
TX_PUT = "PUT" TX_PUT = "PUT"
Expand Down
92 changes: 55 additions & 37 deletions neo4jrestclient/query.py
Expand Up @@ -27,7 +27,7 @@ class BaseQ(object):
"eq", "equals", "neq", "notequals") "eq", "equals", "neq", "notequals")


def __init__(self, property=None, lookup=None, match=None, def __init__(self, property=None, lookup=None, match=None,
nullable=None, var=u"n", **kwargs): nullable=True, var=u"n", **kwargs):
self._and = None self._and = None
self._or = None self._or = None
self._not = None self._not = None
Expand Down Expand Up @@ -137,16 +137,16 @@ def _get_lookup_and_match(self):
match = u"(?i){0}".format(self.match) match = u"(?i){0}".format(self.match)
elif self.lookup == "gt": elif self.lookup == "gt":
lookup = u">" lookup = u">"
match = u"{0}".format(self.match) match = self.match
elif self.lookup == "gte": elif self.lookup == "gte":
lookup = u">" lookup = u">="
match = u"{0}".format(self.match) match = self.match
elif self.lookup == "lt": elif self.lookup == "lt":
lookup = u"<" lookup = u"<"
match = u"{0}".format(self.match) match = self.match
elif self.lookup == "lte": elif self.lookup == "lte":
lookup = u"<" lookup = u"<="
match = u"{0}".format(self.match) match = self.match
elif self.lookup in ["in", "inrange"]: elif self.lookup in ["in", "inrange"]:
lookup = u"IN" lookup = u"IN"
match = u"['{0}']".format(u"', '".join([self._escape(m) match = u"['{0}']".format(u"', '".join([self._escape(m)
Expand Down Expand Up @@ -234,14 +234,16 @@ class CypherException(Exception):
pass pass




class Query(Sequence): class QuerySequence(Sequence):


def __init__(self, cypher, auth, q, params=None, types=None, returns=None): def __init__(self, cypher, auth, q, params=None, types=None, returns=None):
self.q = q self.q = q
self.params = params self.params = params
self.skip = None self._skip = None
self.limit = None self._limit = None
self.returns = returns self._order_by = None
self._returns = returns
self._return_single_rows = False
self._auth = auth self._auth = auth
self._cypher = cypher self._cypher = cypher
# This way we avoid a circular reference, by passing objects like Node # This way we avoid a circular reference, by passing objects like Node
Expand All @@ -253,7 +255,7 @@ def _get_elements(self):
response = self.get_response() response = self.get_response()
try: try:
self._elements = self.cast(elements=response["data"], self._elements = self.cast(elements=response["data"],
returns=self.returns) returns=self._returns)
except: except:
self._elements = response self._elements = response
return self._elements return self._elements
Expand All @@ -275,14 +277,32 @@ def __reversed__(self):
return reversed(self.elements) return reversed(self.elements)


def get_response(self): def get_response(self):
# Preparing slicing and ordering
q = self.q q = self.q
params = self.params params = self.params
if isinstance(self.skip, int) and "_skip" not in params: if self._order_by:
orders = []
for o, order in enumerate(self._order_by):
order_key = "_order_by_%s" % o
if order_key not in params:
nullable = ""
if len(order) == 3:
if order[2] is True:
nullable = "!"
elif order[2] is False:
nullable = "?"
orders.append(u"n.`{%s}`%s %s" % (order_key, nullable, order[1]))
params[order_key] = order[0]
if orders:
q = u"%s order by %s" % (q, ", ".join(orders))
# Lazy slicing
if isinstance(self._skip, int) and "_skip" not in params:
q = u"%s skip {_skip} " % q q = u"%s skip {_skip} " % q
params["_skip"] = self.skip params["_skip"] = self._skip
if isinstance(self.limit, int) and "_limit" not in params: if isinstance(self._limit, int) and "_limit" not in params:
q = u"%s limit {_limit} " % q q = u"%s limit {_limit} " % q
params["_limit"] = self.limit params["_limit"] = self._limit
# Making the real resquest
data = { data = {
"query": q, "query": q,
"params": params, "params": params,
Expand Down Expand Up @@ -342,11 +362,14 @@ def cast(self, elements, returns=None):
casted_row.append(sub_func(element)) casted_row.append(sub_func(element))
else: else:
casted_row.append(func(element)) casted_row.append(func(element))
results.append(casted_row) if self._return_single_rows:
results.append(*casted_row)
else:
results.append(casted_row)
return results return results




class Filter(Query): class FilterSequence(QuerySequence):


def __init__(self, cypher, auth, start=None, lookups=[], def __init__(self, cypher, auth, start=None, lookups=[],
order_by=None, types=None, returns=None): order_by=None, types=None, returns=None):
Expand All @@ -366,25 +389,20 @@ def __init__(self, cypher, auth, start=None, lookups=[],
q = u"%s where %s return n " % (q, where) q = u"%s where %s return n " % (q, where)
else: else:
q = u"%s return n " % q q = u"%s return n " % q
if order_by: super(FilterSequence, self).__init__(cypher=cypher, auth=auth, q=q,
if not isinstance(order_by, (tuple, list)): params=params, types=types,
order_by = [order_by] returns=returns)
orders = [] self._return_single_rows = True
for o, order in enumerate(order_by):
if order.startswith(u"-"):
orders.append(u"n.`{order_by_%s}` desc" % o)
else:
orders.append(u"n.`{order_by_%s}` " % o)
params["order_by_%s" % o] = order
if orders:
q = u"%s order by %s" % (q, ", ".join(orders))
params["start"] = start
super(Filter, self).__init__(cypher=cypher, auth=auth, q=q,
params=params, types=types,
returns=returns)


def __getitem__(self, key): def __getitem__(self, key):
if isinstance(key, slice): if isinstance(key, slice):
self.skip = key.start self._skip = key.start
self.limit = key.stop self._limit = key.stop
return super(Filter, self).__getitem__(key) return super(FilterSequence, self).__getitem__(key)

def order_by(self, property=None, type=None, nullable=True, *args):
if property is None and isinstance(args, (list, tuple)):
self._order_by = args
else:
self._order_by = [(property, type, nullable)]
return self
56 changes: 47 additions & 9 deletions neo4jrestclient/tests.py
Expand Up @@ -1258,7 +1258,7 @@ def test_relationship_pickle(self):
self.assertEqual(r, pickle.loads(p)) self.assertEqual(r, pickle.loads(p))




class QueryAndFilterTestCase(NodesTestCase): class QueryTestCase(PickleTestCase):


def test_query_raw(self): def test_query_raw(self):
n1 = self.gdb.nodes.create(name="John") n1 = self.gdb.nodes.create(name="John")
Expand Down Expand Up @@ -1313,11 +1313,14 @@ def test_query_params_returns_tuple(self):
self.assertEqual(rel, r) self.assertEqual(rel, r)
self.assertEqual(date, 1982) self.assertEqual(date, 1982)



class FilterTestCase(QueryTestCase):

def test_filter_nodes(self): def test_filter_nodes(self):
Q = query.Q Q = query.Q
for i in range(5): for i in range(5):
self.gdb.nodes.create(name="William %s" % i) self.gdb.nodes.create(name="William %s" % i)
lookup = Q("name", istartswith="william", nullable=True) lookup = Q("name", istartswith="william")
williams = self.gdb.nodes.filter(lookup) williams = self.gdb.nodes.filter(lookup)
self.assertTrue(len(williams) >= 5) self.assertTrue(len(williams) >= 5)


Expand All @@ -1326,24 +1329,59 @@ def test_filter_nodes_complex_lookups(self):
for i in range(5): for i in range(5):
self.gdb.nodes.create(name="James", surname="Smith %s" % i) self.gdb.nodes.create(name="James", surname="Smith %s" % i)
lookups = ( lookups = (
Q("name", exact="James", nullable=True) & Q("name", exact="James") &
(Q("surname", startswith="Smith", nullable=True) & (Q("surname", startswith="Smith") &
~Q("surname", endswith="1", nullable=True)) ~Q("surname", endswith="1"))
) )
williams = self.gdb.nodes.filter(lookups) williams = self.gdb.nodes.filter(lookups)
self.assertTrue(len(williams) >= 5) self.assertTrue(len(williams) >= 5)


def test_filter_slicing(self): def test_filter_nodes_slicing(self):
Q = query.Q Q = query.Q
for i in range(5): for i in range(5):
self.gdb.nodes.create(name="William %s" % i) self.gdb.nodes.create(name="William %s" % i)
lookup = Q("name", istartswith="william", nullable=True) lookup = Q("name", istartswith="william")
import ipdb; ipdb.set_trace()
williams = self.gdb.nodes.filter(lookup)[:4] williams = self.gdb.nodes.filter(lookup)[:4]
self.assertTrue(len(williams) == 4) self.assertTrue(len(williams) == 4)


def test_filter_nodes_ordering(self):
Q = query.Q
for i in range(5):
self.gdb.nodes.create(name="William", code=i)
lookup = Q("code", gte=2)
williams = self.gdb.nodes.filter(lookup).order_by("code",
constants.DESC)
self.assertTrue(williams[-1]["code"] > williams[0]["code"])

def test_filter_nodes_nullable(self):
Q = query.Q
for i in range(5):
self.gdb.nodes.create(name="William %s" % i)
lookup = Q("name", istartswith="william", nullable=False)
williams = self.gdb.nodes.filter(lookup)[:10]
self.assertTrue(len(williams) > 5)

def test_filter_nodes_start(self):
Q = query.Q
nodes = []
for i in range(5):
nodes.append(self.gdb.nodes.create(name="William %s" % i))
lookup = Q("name", istartswith="william")
williams = self.gdb.nodes.filter(lookup, start=nodes)
self.assertTrue(len(williams) == 5)

def test_filter_relationships(self):
Q = query.Q
for i in range(10):
n1 = self.gdb.nodes.create(name="William %s" % i)
n2 = self.gdb.nodes.create(name="Rose %s" % i)
n1.loves(n2, since=(1995 + i))
lookup = Q("since", lte=2000)
old_loves = self.gdb.relationships.filter(lookup)
self.assertTrue(len(old_loves) >= 5)



class Neo4jPythonClientTestCase(QueryAndFilterTestCase): class Neo4jPythonClientTestCase(FilterTestCase):
pass pass




Expand Down

0 comments on commit a1f76e9

Please sign in to comment.