Skip to content

Commit

Permalink
Adding order by and filtering for relationships
Browse files Browse the repository at this point in the history
  • Loading branch information
versae committed Nov 7, 2012
1 parent b2c5c38 commit a1f76e9
Show file tree
Hide file tree
Showing 4 changed files with 143 additions and 57 deletions.
49 changes: 38 additions & 11 deletions neo4jrestclient/client.py
Expand Up @@ -10,7 +10,7 @@
from lucenequerybuilder import Q

import options
from query import Query, Filter, CypherException
from query import QuerySequence, FilterSequence, CypherException
from constants import (BREADTH_FIRST, DEPTH_FIRST,
STOP_AT_END_OF_GRAPH,
NODE_GLOBAL, NODE_PATH, NODE_RECENT,
Expand Down Expand Up @@ -191,8 +191,8 @@ def query(self, q, params=None, returns=RAW):
"path": Path,
"position": Position,
}
return Query(self._cypher, self._auth, q=q, params=params,
types=types, returns=returns)
return QuerySequence(self._cypher, self._auth, q=q, params=params,
types=types, returns=returns)
else:
raise CypherException

Expand Down Expand Up @@ -975,17 +975,17 @@ def delete(self, key, tx=None):
node = self.__getitem__(key)
del node

def filter(self, lookups=[], start=None, skip=None, limit=None):
def filter(self, lookups=[], start=None):
if self._cypher:
if start:
starts = []
if not isinstance(start, (list, tuple)):
starts = [start]
for start_element in starts:
if isinstance(start_element, Node):
starts.append(start_element.id)
start = [start]
for start_element in start:
if hasattr(start_element, "id"):
starts.append(unicode(start_element.id))
else:
starts.append(start_element)
starts.append(unicode(start_element))
else:
starts = u"*"
start = u"node(%s)" % u", ".join(starts)
Expand All @@ -997,8 +997,8 @@ def filter(self, lookups=[], start=None, skip=None, limit=None):
"path": Path,
"position": Position,
}
return Filter(self._cypher, self._auth, start=start, types=types,
lookups=lookups, returns=(Node))
return FilterSequence(self._cypher, self._auth, start=start,
types=types, lookups=lookups, returns=Node)
else:
raise CypherException

Expand Down Expand Up @@ -1658,6 +1658,33 @@ def delete(self, key, tx=None):
relationship = self.__getitem__(key)
del relationship

def filter(self, lookups=[], start=None):
if self._cypher:
if start:
starts = []
if not isinstance(start, (list, tuple)):
start = [start]
for start_element in start:
if hasattr(start_element, "id"):
starts.append(unicode(start_element.id))
else:
starts.append(unicode(start_element))
else:
starts = u"*"
start = u"rel(%s)" % u", ".join(starts)
if not isinstance(lookups, (list, tuple)):
lookups = [lookups]
types = {
"node": Node,
"relationship": Relationship,
"path": Path,
"position": Position,
}
return FilterSequence(self._cypher, self._auth, start=start,
types=types, lookups=lookups, returns=Node)
else:
raise CypherException

def _indexes(self):
if self._relationship_index:
return IndexesProxy(self._relationship_index, RELATIONSHIP,
Expand Down
3 changes: 3 additions & 0 deletions neo4jrestclient/constants.py
Expand Up @@ -42,6 +42,9 @@
INDEX_RELATIONSHIP = "index_relationship"
INDEX_EXACT = "exact"
INDEX_FULLTEXT = "fulltext"
# Cypher ordering
ASC = "asc"
DESC = "desc"
# Transactions
TX_GET = "GET"
TX_PUT = "PUT"
Expand Down
92 changes: 55 additions & 37 deletions neo4jrestclient/query.py
Expand Up @@ -27,7 +27,7 @@ class BaseQ(object):
"eq", "equals", "neq", "notequals")

def __init__(self, property=None, lookup=None, match=None,
nullable=None, var=u"n", **kwargs):
nullable=True, var=u"n", **kwargs):
self._and = None
self._or = None
self._not = None
Expand Down Expand Up @@ -137,16 +137,16 @@ def _get_lookup_and_match(self):
match = u"(?i){0}".format(self.match)
elif self.lookup == "gt":
lookup = u">"
match = u"{0}".format(self.match)
match = self.match
elif self.lookup == "gte":
lookup = u">"
match = u"{0}".format(self.match)
lookup = u">="
match = self.match
elif self.lookup == "lt":
lookup = u"<"
match = u"{0}".format(self.match)
match = self.match
elif self.lookup == "lte":
lookup = u"<"
match = u"{0}".format(self.match)
lookup = u"<="
match = self.match
elif self.lookup in ["in", "inrange"]:
lookup = u"IN"
match = u"['{0}']".format(u"', '".join([self._escape(m)
Expand Down Expand Up @@ -234,14 +234,16 @@ class CypherException(Exception):
pass


class Query(Sequence):
class QuerySequence(Sequence):

def __init__(self, cypher, auth, q, params=None, types=None, returns=None):
self.q = q
self.params = params
self.skip = None
self.limit = None
self.returns = returns
self._skip = None
self._limit = None
self._order_by = None
self._returns = returns
self._return_single_rows = False
self._auth = auth
self._cypher = cypher
# This way we avoid a circular reference, by passing objects like Node
Expand All @@ -253,7 +255,7 @@ def _get_elements(self):
response = self.get_response()
try:
self._elements = self.cast(elements=response["data"],
returns=self.returns)
returns=self._returns)
except:
self._elements = response
return self._elements
Expand All @@ -275,14 +277,32 @@ def __reversed__(self):
return reversed(self.elements)

def get_response(self):
# Preparing slicing and ordering
q = self.q
params = self.params
if isinstance(self.skip, int) and "_skip" not in params:
if self._order_by:
orders = []
for o, order in enumerate(self._order_by):
order_key = "_order_by_%s" % o
if order_key not in params:
nullable = ""
if len(order) == 3:
if order[2] is True:
nullable = "!"
elif order[2] is False:
nullable = "?"
orders.append(u"n.`{%s}`%s %s" % (order_key, nullable, order[1]))
params[order_key] = order[0]
if orders:
q = u"%s order by %s" % (q, ", ".join(orders))
# Lazy slicing
if isinstance(self._skip, int) and "_skip" not in params:
q = u"%s skip {_skip} " % q
params["_skip"] = self.skip
if isinstance(self.limit, int) and "_limit" not in params:
params["_skip"] = self._skip
if isinstance(self._limit, int) and "_limit" not in params:
q = u"%s limit {_limit} " % q
params["_limit"] = self.limit
params["_limit"] = self._limit
# Making the real resquest
data = {
"query": q,
"params": params,
Expand Down Expand Up @@ -342,11 +362,14 @@ def cast(self, elements, returns=None):
casted_row.append(sub_func(element))
else:
casted_row.append(func(element))
results.append(casted_row)
if self._return_single_rows:
results.append(*casted_row)
else:
results.append(casted_row)
return results


class Filter(Query):
class FilterSequence(QuerySequence):

def __init__(self, cypher, auth, start=None, lookups=[],
order_by=None, types=None, returns=None):
Expand All @@ -366,25 +389,20 @@ def __init__(self, cypher, auth, start=None, lookups=[],
q = u"%s where %s return n " % (q, where)
else:
q = u"%s return n " % q
if order_by:
if not isinstance(order_by, (tuple, list)):
order_by = [order_by]
orders = []
for o, order in enumerate(order_by):
if order.startswith(u"-"):
orders.append(u"n.`{order_by_%s}` desc" % o)
else:
orders.append(u"n.`{order_by_%s}` " % o)
params["order_by_%s" % o] = order
if orders:
q = u"%s order by %s" % (q, ", ".join(orders))
params["start"] = start
super(Filter, self).__init__(cypher=cypher, auth=auth, q=q,
params=params, types=types,
returns=returns)
super(FilterSequence, self).__init__(cypher=cypher, auth=auth, q=q,
params=params, types=types,
returns=returns)
self._return_single_rows = True

def __getitem__(self, key):
if isinstance(key, slice):
self.skip = key.start
self.limit = key.stop
return super(Filter, self).__getitem__(key)
self._skip = key.start
self._limit = key.stop
return super(FilterSequence, self).__getitem__(key)

def order_by(self, property=None, type=None, nullable=True, *args):
if property is None and isinstance(args, (list, tuple)):
self._order_by = args
else:
self._order_by = [(property, type, nullable)]
return self
56 changes: 47 additions & 9 deletions neo4jrestclient/tests.py
Expand Up @@ -1258,7 +1258,7 @@ def test_relationship_pickle(self):
self.assertEqual(r, pickle.loads(p))


class QueryAndFilterTestCase(NodesTestCase):
class QueryTestCase(PickleTestCase):

def test_query_raw(self):
n1 = self.gdb.nodes.create(name="John")
Expand Down Expand Up @@ -1313,11 +1313,14 @@ def test_query_params_returns_tuple(self):
self.assertEqual(rel, r)
self.assertEqual(date, 1982)


class FilterTestCase(QueryTestCase):

def test_filter_nodes(self):
Q = query.Q
for i in range(5):
self.gdb.nodes.create(name="William %s" % i)
lookup = Q("name", istartswith="william", nullable=True)
lookup = Q("name", istartswith="william")
williams = self.gdb.nodes.filter(lookup)
self.assertTrue(len(williams) >= 5)

Expand All @@ -1326,24 +1329,59 @@ def test_filter_nodes_complex_lookups(self):
for i in range(5):
self.gdb.nodes.create(name="James", surname="Smith %s" % i)
lookups = (
Q("name", exact="James", nullable=True) &
(Q("surname", startswith="Smith", nullable=True) &
~Q("surname", endswith="1", nullable=True))
Q("name", exact="James") &
(Q("surname", startswith="Smith") &
~Q("surname", endswith="1"))
)
williams = self.gdb.nodes.filter(lookups)
self.assertTrue(len(williams) >= 5)

def test_filter_slicing(self):
def test_filter_nodes_slicing(self):
Q = query.Q
for i in range(5):
self.gdb.nodes.create(name="William %s" % i)
lookup = Q("name", istartswith="william", nullable=True)
import ipdb; ipdb.set_trace()
lookup = Q("name", istartswith="william")
williams = self.gdb.nodes.filter(lookup)[:4]
self.assertTrue(len(williams) == 4)

def test_filter_nodes_ordering(self):
Q = query.Q
for i in range(5):
self.gdb.nodes.create(name="William", code=i)
lookup = Q("code", gte=2)
williams = self.gdb.nodes.filter(lookup).order_by("code",
constants.DESC)
self.assertTrue(williams[-1]["code"] > williams[0]["code"])

def test_filter_nodes_nullable(self):
Q = query.Q
for i in range(5):
self.gdb.nodes.create(name="William %s" % i)
lookup = Q("name", istartswith="william", nullable=False)
williams = self.gdb.nodes.filter(lookup)[:10]
self.assertTrue(len(williams) > 5)

def test_filter_nodes_start(self):
Q = query.Q
nodes = []
for i in range(5):
nodes.append(self.gdb.nodes.create(name="William %s" % i))
lookup = Q("name", istartswith="william")
williams = self.gdb.nodes.filter(lookup, start=nodes)
self.assertTrue(len(williams) == 5)

def test_filter_relationships(self):
Q = query.Q
for i in range(10):
n1 = self.gdb.nodes.create(name="William %s" % i)
n2 = self.gdb.nodes.create(name="Rose %s" % i)
n1.loves(n2, since=(1995 + i))
lookup = Q("since", lte=2000)
old_loves = self.gdb.relationships.filter(lookup)
self.assertTrue(len(old_loves) >= 5)


class Neo4jPythonClientTestCase(QueryAndFilterTestCase):
class Neo4jPythonClientTestCase(FilterTestCase):
pass


Expand Down

0 comments on commit a1f76e9

Please sign in to comment.