Permalink
Browse files

Adding support for retrieving index elements in a transaction

  • Loading branch information...
versae committed Jul 13, 2012
1 parent 851144d commit 73266a08c0eda021b7ca84000cf310b4480b3297
Showing with 57 additions and 21 deletions.
  1. +39 −20 neo4jrestclient/client.py
  2. +18 −1 neo4jrestclient/tests.py
View
@@ -224,11 +224,13 @@ def __getattribute__(self, attr, *args, **kwargs):
"This is needed in order to handle pickling in "
"nodes.",
DeprecationWarning)
+
def _create_relationship(*args, **kwargs):
_attr = "_create_relationship"
_func = object.__getattribute__(self, _attr)
_relationship = _func(attr)
return _relationship(*args, **kwargs)
+
return _create_relationship
else:
return object.__getattribute__(self, attr)
@@ -250,7 +252,7 @@ def _get_properties(self):
return dict(self)
def _set_properties(self, props={}):
- _type = object.__getattribute__(self, "_extras")["type"]
+ # _type = object.__getattribute__(self, "_extras")["type"]
if type == RELATIONSHIP:
_body = dict.__getitem__(self, "body")
_data = dict.__getitem__(_body, "data")
@@ -260,7 +262,7 @@ def _set_properties(self, props={}):
_body.update(props)
def _del_properties(self):
- _type = object.__getattribute__(self, "_extras")["type"]
+ # _type = object.__getattribute__(self, "_extras")["type"]
if type == RELATIONSHIP:
_body = dict.__getitem__(self, "body")
dict.__setitem__(_body, "data", {})
@@ -291,7 +293,7 @@ def __getitem__(self, key):
eltos = _proxy._list[key]
if _proxy._attribute:
return [_proxy._class(elto[_proxy._attribute],
- update_dict=elto) for elto in eltos]
+ update_dict=elto) for elto in eltos]
else:
return [_proxy._class(elto) for elto in eltos]
else:
@@ -300,11 +302,14 @@ def __getitem__(self, key):
if _type == RELATIONSHIP:
_body = dict.__getitem__(self, "body")
return dict.__getitem__(_body, key)
+ elif _type == ITERABLE:
+ _body = dict.__getitem__(self, "body")
+ dict.__setitem__(self, "key", key)
+ return self
else:
_body = dict.__getitem__(self, "body")
return dict.__getitem__(_body, key)
-
def __setitem__(self, key, val):
_type = object.__getattribute__(self, "_extras")["type"]
_proxy = object.__getattribute__(self, "_proxy")
@@ -319,7 +324,6 @@ def __setitem__(self, key, val):
_body = dict.__getitem__(self, "body")
return dict.__setitem__(_body, key, val)
-
def get_object(self):
if self._object_ref:
return self._object_ref()
@@ -346,16 +350,29 @@ def change(self, cls, url, data=None, auth=None):
if not data["body"] or len(data["body"]) == 0:
self._proxy = list()
else:
- first_element = data["body"][0]
- if "self" in first_element:
- if NODE in first_element["self"]:
- self._proxy = cls(Node, data["body"], "self",
- auth=_auth)
- elif RELATIONSHIP in first_element["self"]:
- self._proxy = cls(Relationship, data["body"], "self",
- auth=_auth)
+ self_keys = self().keys()
+ # Check if it is an element from the iterable
+ if "key" in self_keys and "of" in self_keys:
+ _key = dict.__getitem__(self, "key")
+ _of = dict.__getitem__(self, "of")
+ _body = data["body"][_key]
+ if _of == NODE:
+ cls = Node
+ else: # Relationship
+ cls = Relationship
+ self._proxy = cls(_body["self"],
+ update_dict=_body, auth=_auth)
else:
- self._proxy = cls(Path, data["body"], auth=_auth)
+ first_element = data["body"][0]
+ if "self" in first_element:
+ if NODE in first_element["self"]:
+ self._proxy = cls(Node, data["body"], "self",
+ auth=_auth)
+ elif RELATIONSHIP in first_element["self"]:
+ self._proxy = cls(Relationship, data["body"],
+ "self", auth=_auth)
+ else:
+ self._proxy = cls(Path, data["body"], auth=_auth)
else:
if "self" in data["body"] and data["body"]["self"] != url:
self._proxy = cls(data["body"]["self"],
@@ -465,7 +482,6 @@ def add(self, key, value, item, tx=None):
return op
-
class Transaction(object):
"""
Transaction class.
@@ -576,7 +592,8 @@ def commit(self, *args, **kwargs):
else:
return True
- def subscribe(self, method, url, data=None, obj=None, returns=None):
+ def subscribe(self, method, url, data=None, obj=None, returns=None,
+ of=None):
job_id = len(self.operations)
if url.startswith("{"):
url_to = url
@@ -588,6 +605,7 @@ def subscribe(self, method, url, data=None, obj=None, returns=None):
"method": method,
"to": url_to,
"id": job_id,
+ "of": of,
}
if data:
params.update({"body": data})
@@ -1031,7 +1049,7 @@ def items(self):
return self._dic["data"].viewitems()
except AttributeError:
return self._dic["data"].items()
-
+
def traverse(self, types=None, order=None, stop=None, returnable=None,
uniqueness=None, is_stop_node=None, is_returnable=None,
paginated=False, page_size=None, time_out=None,
@@ -1290,7 +1308,8 @@ class Index(object):
def _get_results(url, node_or_rel, auth={}, tx=None):
tx = Transaction.get_transaction(tx)
if tx:
- return tx.subscribe(TX_GET, url, obj=None, returns=ITERABLE)
+ return tx.subscribe(TX_GET, url, obj=None, returns=ITERABLE,
+ of=node_or_rel)
else:
response, content = Request(**auth).get(url)
if response.status == 200:
@@ -1598,7 +1617,7 @@ def __init__(self, node, auth=None):
self._len = 0
def __getattr__(self, relationship_type):
- auth = object.__getattribute__(self, "_auth")
+ # auth = object.__getattribute__(self, "_auth")
def get_relationships(types=None, *args, **kwargs):
tx = Transaction.get_transaction(kwargs.get("tx", None))
@@ -1802,12 +1821,12 @@ def __getattr__(self, attr):
'type': property(lambda self: attr),
})()
-
All = BaseInAndOut(direction=RELATIONSHIPS_ALL)
Incoming = BaseInAndOut(direction=RELATIONSHIPS_IN)
Outgoing = BaseInAndOut(direction=RELATIONSHIPS_OUT)
Undirected = BaseInAndOut(direction="both") # Deprecated, use "All" instead
+
class Direction(object):
ANY = All
INCOMING = Incoming
View
@@ -1041,7 +1041,7 @@ def test_transaction_index_query(self):
tx = self.gdb.transaction()
index_hits = index['test2']['test2']
tx.commit()
- self.assertTrue(n1 == index_hits[:][-1])
+ self.assertTrue(n1 == index_hits[-1])
def test_transaction_remove_node_from_index(self):
index = self.gdb.nodes.indexes.create('index3')
@@ -1182,6 +1182,23 @@ def test_transaction_access_node(self):
rel = frame.FRAME_EDGE(edge)
self.assertTrue(isinstance(rel, client.Relationship))
+ # Test from http://stackoverflow.com/questions/11407546/
+ def test_a_transaction_index_access_create_relationship(self):
+ s = self.gdb.node.create(id=1)
+ d = self.gdb.node.create(id=2)
+ nidx = self.gdb.nodes.indexes.create('nodelist')
+ nidx.add('nid',1, s)
+ nidx.add('nid',2, d)
+ nodelist = [(1, 2)]
+ with self.gdb.transaction():
+ for s_id, d_id in nodelist:
+ sn = nidx['nid'][s_id][-1]
+ dn = nidx['nid'][d_id][-1]
+# rel = sn.Follows(dn)
+# self.assertTrue(isinstance(rel, client.Relationship))
+ self.assertEqual(s, sn)
+ self.assertEqual(d, dn)
+
class PickleTestCase(TransactionsTestCase):

0 comments on commit 73266a0

Please sign in to comment.