-
Notifications
You must be signed in to change notification settings - Fork 52
/
Copy pathnode_classification_model.py
76 lines (56 loc) · 2.32 KB
/
node_classification_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
from typing import Any
from pandas import Series
from ..call_parameters import CallParameters
from ..graph.graph_object import Graph
from ..graph.graph_type_check import graph_type_check
from .pipeline_model import PipelineModel
class NCModel(PipelineModel):
    """
    A trained node classification pipeline model held in the model catalog.

    Instances are not built directly; obtain one from
    :func:`NCTrainingPipeline.train() <graphdatascience.pipeline.nc_training_pipeline.NCTrainingPipeline.train>`.
    """

    def _endpoint_prefix(self) -> str:
        # Common namespace of the node classification prediction procedures;
        # callers append the mode suffix (e.g. "write").
        return "gds.beta.pipeline.nodeClassification.predict."

    @graph_type_check
    def predict_write(self, G: Graph, **config: Any) -> "Series[Any]":
        """
        Classify the nodes of a graph with this model and write the predictions to the database.

        The ``modelName`` entry of ``config`` is always set to this model's
        name, replacing any value supplied by the caller.

        Args:
            G: The graph to classify.
            **config: Additional configuration forwarded to the write procedure.

        Returns:
            The write procedure's result, reduced with ``squeeze()``.
        """
        procedure = self._endpoint_prefix() + "write"
        config["modelName"] = self.name()
        call_params = CallParameters(graph_name=G.name(), config=config)

        return self._query_runner.call_procedure(  # type: ignore
            endpoint=procedure,
            params=call_params,
            logging=True,
        ).squeeze()

    @graph_type_check
    def predict_write_estimate(self, G: Graph, **config: Any) -> "Series[Any]":
        """
        Estimate the memory that :meth:`predict_write` would need for the same inputs.

        Args:
            G: The graph that would be classified.
            **config: The configuration that would be passed to the write prediction.

        Returns:
            The memory estimate for predicting labels of ``G`` and writing them back.
        """
        graph_name = G.name()
        return self._estimate_predict("write", graph_name, config)

    def classes(self) -> list[int]:
        """
        List the classes recorded for this model.

        Returns:
            The ``classes`` entry of the model's ``modelInfo``.
        """
        model_info = self._list_info()["modelInfo"]
        return model_info["classes"]  # type: ignore

    def feature_properties(self) -> list[str]:
        """
        List the feature property names recorded in the model's pipeline.

        Returns:
            The ``feature`` value of each entry in
            ``modelInfo.pipeline.featureProperties``, in their stored order.
        """
        pipeline_info = self._list_info()["modelInfo"]["pipeline"]
        feature_entries: list[dict[str, Any]] = pipeline_info["featureProperties"]
        return [entry["feature"] for entry in feature_entries]