Skip to content

Commit 5f708bb

Browse files
committed
Support gds.util.asNode(s) in Sessions
1 parent a7136a8 commit 5f708bb

File tree

10 files changed

+172
-37
lines changed

10 files changed

+172
-37
lines changed

graphdatascience/call_builder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from .algo.algo_endpoints import AlgoEndpoints
22
from .error.uncallable_namespace import UncallableNamespace
3-
from .utils.util_endpoints import IndirectUtilAlphaEndpoints
3+
from .utils.direct_util_endpoints import IndirectUtilAlphaEndpoints
44

55

66
class IndirectCallBuilder(AlgoEndpoints, UncallableNamespace):

graphdatascience/endpoints.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
SystemBetaEndpoints,
2424
)
2525
from .topological_lp.topological_lp_endpoints import TopologicalLPAlphaEndpoints
26-
from .utils.util_endpoints import DirectUtilEndpoints
26+
from .utils.direct_util_endpoints import DirectUtilEndpoints
2727

2828
"""
2929
This class should inherit endpoint classes that only contain endpoints that can be called directly from

graphdatascience/graph_data_science.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from .query_runner.query_runner import QueryRunner
1414
from .server_version.server_version import ServerVersion
1515
from graphdatascience.graph.graph_proc_runner import GraphProcRunner
16+
from graphdatascience.utils.util_proc_runner import UtilProcRunner
1617

1718

1819
class GraphDataScience(DirectEndpoints, UncallableNamespace):
@@ -81,12 +82,16 @@ def __init__(
8182
None if arrow is True else arrow,
8283
)
8384

84-
super().__init__(self._query_runner, "gds", self._server_version)
85+
super().__init__(self._query_runner, namespace="gds", server_version=self._server_version)
8586

8687
@property
8788
def graph(self) -> GraphProcRunner:
8889
return GraphProcRunner(self._query_runner, f"{self._namespace}.graph", self._server_version)
8990

91+
@property
92+
def util(self) -> UtilProcRunner:
93+
return UtilProcRunner(self._query_runner, f"{self._namespace}.util", self._server_version)
94+
9095
@property
9196
def alpha(self) -> AlphaEndpoints:
9297
return AlphaEndpoints(self._query_runner, "gds.alpha", self._server_version)

graphdatascience/tests/integration/test_util_ops.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from graphdatascience.graph_data_science import GraphDataScience
77
from graphdatascience.query_runner.neo4j_query_runner import Neo4jQueryRunner
88
from graphdatascience.server_version.server_version import ServerVersion
9+
from graphdatascience.session.aura_graph_data_science import AuraGraphDataScience
910

1011

1112
@pytest.fixture(autouse=True)
@@ -97,6 +98,14 @@ def test_util_asNode(gds: GraphDataScience) -> None:
9798
assert result["name"] == "A"
9899

99100

101+
@pytest.mark.cloud_architecture
102+
@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 7, 0))
103+
def test_remote_util_as_node(gds_with_cloud_setup: AuraGraphDataScience) -> None:
104+
id = gds_with_cloud_setup.find_node_id(["Location"], {"name": "A"})
105+
result = gds_with_cloud_setup.util.asNode(id)
106+
assert result["name"] == "A"
107+
108+
100109
def test_util_asNodes(gds: GraphDataScience) -> None:
101110
ids = [
102111
gds.find_node_id(["Location"], {"name": "A"}),
@@ -106,6 +115,17 @@ def test_util_asNodes(gds: GraphDataScience) -> None:
106115
assert len(result) == 2
107116

108117

118+
@pytest.mark.cloud_architecture
119+
@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 7, 0))
120+
def test_remote_util_as_nodes(gds_with_cloud_setup: AuraGraphDataScience) -> None:
121+
ids = [
122+
gds_with_cloud_setup.find_node_id(["Location"], {"name": "A"}),
123+
gds_with_cloud_setup.find_node_id(["Location"], {"name": 2}),
124+
]
125+
result = gds_with_cloud_setup.util.asNodes(ids)
126+
assert len(result) == 2
127+
128+
109129
def test_util_nodeProperty(gds: GraphDataScience, G: Graph) -> None:
110130
id = gds.find_node_id(["Location"], {"name": "A"})
111131
result = gds.util.nodeProperty(G, id, "population")

graphdatascience/tests/unit/conftest.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,13 @@ def call_procedure(
4444

4545
return self.run_cypher(query, params, database, custom_error)
4646

47+
def call_function(self, endpoint: str, params: Optional[CallParameters] = None) -> Any:
48+
if params is None:
49+
params = CallParameters()
50+
query = f"RETURN {endpoint}({params.placeholder_str()})"
51+
52+
return self.run_cypher(query, params).squeeze()
53+
4754
def run_cypher(
4855
self, query: str, params: Optional[Dict[str, Any]] = None, db: Optional[str] = None, custom_error: bool = True
4956
) -> DataFrame:

graphdatascience/tests/unit/test_util_ops.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
from graphdatascience.graph.graph_object import Graph
12
from graphdatascience.graph_data_science import GraphDataScience
3+
from graphdatascience.session.aura_graph_data_science import AuraGraphDataScience
24
from graphdatascience.tests.unit.conftest import CollectingQueryRunner
35

46

@@ -11,3 +13,55 @@ def test_list(runner: CollectingQueryRunner, gds: GraphDataScience) -> None:
1113

1214
assert runner.last_query() == "CALL gds.list()"
1315
assert runner.last_params() == {}
16+
17+
18+
def test_as_node(runner: CollectingQueryRunner, gds: GraphDataScience) -> None:
19+
gds.util.asNode(1)
20+
21+
assert runner.last_query() == "RETURN gds.util.asNode(1) AS node"
22+
23+
24+
def test_remote_as_node(runner: CollectingQueryRunner, aura_gds: AuraGraphDataScience) -> None:
25+
aura_gds.util.asNode(1)
26+
27+
assert runner.last_query() == "MATCH (n) WHERE id(n) = $nodeId RETURN n"
28+
assert runner.last_params() == {"nodeId": 1}
29+
30+
31+
def test_as_nodes(runner: CollectingQueryRunner, gds: GraphDataScience) -> None:
32+
gds.util.asNodes([1, 2, 3])
33+
34+
assert runner.last_query() == "RETURN gds.util.asNodes([1, 2, 3]) AS nodes"
35+
36+
37+
def test_remote_as_nodes(runner: CollectingQueryRunner, aura_gds: AuraGraphDataScience) -> None:
38+
aura_gds.util.asNodes([1, 2, 3])
39+
40+
assert runner.last_query() == "MATCH (n) WHERE id(n) IN $nodeIds RETURN collect(n)"
41+
assert runner.last_params() == {"nodeId": [1, 2, 3]}
42+
43+
44+
def test_node_property(runner: CollectingQueryRunner, gds: GraphDataScience) -> None:
45+
G = Graph("g", runner, gds._server_version)
46+
gds.util.nodeProperty(G, 1, "my_prop", "my_label")
47+
48+
assert runner.last_query() == "RETURN gds.util.nodeProperty($graph_name, $node_id, $property_key, $node_label)"
49+
assert runner.last_params() == {
50+
"graph_name": "g",
51+
"node_id": 1,
52+
"property_key": "my_prop",
53+
"node_label": "my_label",
54+
}
55+
56+
57+
def test_remote_node_property(runner: CollectingQueryRunner, aura_gds: AuraGraphDataScience) -> None:
58+
G = Graph("g", runner, aura_gds._server_version)
59+
aura_gds.util.nodeProperty(G, 1, "my_prop", "my_label")
60+
61+
assert runner.last_query() == "RETURN gds.util.nodeProperty($graph_name, $node_id, $property_key, $node_label)"
62+
assert runner.last_params() == {
63+
"graph_name": "g",
64+
"node_id": 1,
65+
"property_key": "my_prop",
66+
"node_label": "my_label",
67+
}

graphdatascience/utils/util_endpoints.py renamed to graphdatascience/utils/direct_util_endpoints.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
from ..caller_base import CallerBase
66
from ..error.client_only_endpoint import client_only_endpoint
7-
from .util_proc_runner import UtilProcRunner
87
from graphdatascience.call_parameters import CallParameters
98
from graphdatascience.error.cypher_warning_handler import (
109
filter_id_func_deprecation_warning,
@@ -86,10 +85,6 @@ def list(self) -> DataFrame:
8685
namespace = self._namespace + ".list"
8786
return self._query_runner.call_procedure(endpoint=namespace, custom_error=False)
8887

89-
@property
90-
def util(self) -> UtilProcRunner:
91-
return UtilProcRunner(self._query_runner, f"{self._namespace}.util", self._server_version)
92-
9388

9489
class IndirectUtilAlphaEndpoints(CallerBase):
9590
def oneHotEncoding(self, available_values: List[Any], selected_values: List[Any]) -> List[int]:
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
from typing import Any
2+
3+
from graphdatascience.call_parameters import CallParameters
4+
from graphdatascience.error.illegal_attr_checker import IllegalAttrChecker
5+
from graphdatascience.graph.graph_object import Graph
6+
from graphdatascience.graph.graph_type_check import graph_type_check
7+
8+
9+
class NodePropertyFuncRunner(IllegalAttrChecker):
10+
@graph_type_check
11+
def __call__(self, G: Graph, node_id: int, property_key: str, node_label: str = "*") -> Any:
12+
"""
13+
Get the property of a node with the given id.
14+
15+
Args:
16+
G: The graph to get the node property from.
17+
node_id: The id of the node to get the property from.
18+
property_key: The key of the property to get.
19+
node_label: The label of the node to get the property from.
20+
21+
Returns:
22+
The property of the node with the given id.
23+
24+
"""
25+
params = CallParameters(
26+
graph_name=G.name(),
27+
node_id=node_id,
28+
property_key=property_key,
29+
node_label=node_label,
30+
)
31+
return self._query_runner.call_function(endpoint=self._namespace, params=params)

graphdatascience/utils/util_proc_runner.py

Lines changed: 4 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@
22

33
from ..error.illegal_attr_checker import IllegalAttrChecker
44
from ..error.uncallable_namespace import UncallableNamespace
5-
from ..graph.graph_object import Graph
6-
from ..graph.graph_type_check import graph_type_check
5+
from graphdatascience.utils.util_node_property_func_runner import NodePropertyFuncRunner
76

87

98
class UtilProcRunner(UncallableNamespace, IllegalAttrChecker):
@@ -39,30 +38,6 @@ def asNodes(self, node_ids: List[int]) -> List[Any]:
3938

4039
return result.iat[0, 0] # type: ignore
4140

42-
@graph_type_check
43-
def nodeProperty(self, G: Graph, node_id: int, property_key: str, node_label: str = "*") -> Any:
44-
"""
45-
Get the property of a node with the given id.
46-
47-
Args:
48-
G: The graph to get the node property from.
49-
node_id: The id of the node to get the property from.
50-
property_key: The key of the property to get.
51-
node_label: The label of the node to get the property from.
52-
53-
Returns:
54-
The property of the node with the given id.
55-
56-
"""
57-
self._namespace += ".nodeProperty"
58-
59-
query = f"RETURN {self._namespace}($graph_name, $node_id, $property_key, $node_label) as property"
60-
params = {
61-
"graph_name": G.name(),
62-
"node_id": node_id,
63-
"property_key": property_key,
64-
"node_label": node_label,
65-
}
66-
result = self._query_runner.run_cypher(query, params)
67-
68-
return result.iat[0, 0]
41+
@property
42+
def nodeProperty(self) -> NodePropertyFuncRunner:
43+
return NodePropertyFuncRunner(self._query_runner, self._namespace + ".nodeProperty", self._server_version)
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
from typing import Any, List
2+
3+
from ..error.illegal_attr_checker import IllegalAttrChecker
4+
from ..error.uncallable_namespace import UncallableNamespace
5+
from graphdatascience.error.cypher_warning_handler import (
6+
filter_id_func_deprecation_warning,
7+
)
8+
from graphdatascience.utils.util_node_property_func_runner import NodePropertyFuncRunner
9+
10+
11+
class UtilRemoteProcRunner(UncallableNamespace, IllegalAttrChecker):
12+
@filter_id_func_deprecation_warning()
13+
def asNode(self, node_id: int) -> Any:
14+
"""
15+
Get a node from a node id.
16+
17+
Args:
18+
node_id: The id of the node to get.
19+
20+
Returns:
21+
The node with the given id.
22+
23+
"""
24+
query = "MATCH (n) WHERE id(n) = $nodeId RETURN n"
25+
params = {"nodeId": node_id}
26+
27+
return self._query_runner.run_cypher(query=query, params=params).squeeze()
28+
29+
@filter_id_func_deprecation_warning()
30+
def asNodes(self, node_ids: List[int]) -> List[Any]:
31+
"""
32+
Get a list of nodes from a list of node ids.
33+
34+
Args:
35+
node_ids: The ids of the nodes to get.
36+
37+
Returns:
38+
The nodes with the given ids.
39+
40+
"""
41+
query = "MATCH (n) WHERE id(n) IN $nodeIds RETURN collect(n)"
42+
params = {"nodeId": node_ids}
43+
44+
return self._query_runner.run_cypher(query=query, params=params).squeeze() # type: ignore
45+
46+
@property
47+
def nodeProperty(self) -> NodePropertyFuncRunner:
48+
return NodePropertyFuncRunner(self._query_runner, self._namespace + ".nodeProperty", self._server_version)

0 commit comments

Comments
 (0)