-
Notifications
You must be signed in to change notification settings - Fork 52
/
Copy pathgraphsage_model.py
54 lines (40 loc) · 1.68 KB
/
graphsage_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
from typing import Any
from pandas import Series
from ..call_parameters import CallParameters
from ..graph.graph_object import Graph
from ..graph.graph_type_check import graph_type_check
from .model import Model
class GraphSageModel(Model):
    """
    A trained GraphSAGE model held in the model catalog.

    Instances are not meant to be built by hand; obtain one from
    :func:`gds.beta.graphSage.train()`.
    """

    def _endpoint_prefix(self) -> str:
        # All GraphSAGE procedures used by this class share this namespace.
        prefix = "gds.beta.graphSage."
        return prefix

    @graph_type_check
    def predict_write(self, G: Graph, **config: Any) -> "Series[Any]":
        """
        Compute embeddings for ``G`` with this model and write them to the database.

        The ``modelName`` entry of ``config`` is always set to this model's
        name, replacing any value the caller may have supplied.

        Args:
            G: The graph whose embeddings should be generated.
            **config: Additional configuration forwarded to the write procedure.

        Returns:
            The result of the ``write`` procedure call, reduced with ``.squeeze()``.
        """
        write_endpoint = f"{self._endpoint_prefix()}write"
        # Bind the procedure call to this particular catalog model.
        config["modelName"] = self.name()
        call_params = CallParameters(graph_name=G.name(), config=config)
        return self._query_runner.call_procedure(  # type: ignore
            endpoint=write_endpoint,
            params=call_params,
            logging=True,
        ).squeeze()

    @graph_type_check
    def predict_write_estimate(self, G: Graph, **config: Any) -> "Series[Any]":
        """
        Estimate the memory that :meth:`predict_write` would need for ``G``.

        Args:
            G: The graph embeddings would be generated for.
            **config: Configuration of the prediction being estimated.

        Returns:
            The memory estimate for generating embeddings for ``G`` and writing
            them to the database.
        """
        graph_name = G.name()
        # `_estimate_predict` is not defined here; presumably inherited from Model.
        return self._estimate_predict("write", graph_name, config)