Skip to content

Commit 35a3358

Browse files
Adding save/load feature for pipelines. The semantics are as follows:
1. Save/load on an untrained pipeline saves the graph. 2. On a trained/selected pipeline, it saves the node state. 3. Tests for simple graph save have been added.
1 parent c4c7990 commit 35a3358

File tree

5 files changed

+121
-1
lines changed

5 files changed

+121
-1
lines changed

codeflare/pipelines/Datamodel.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from sklearn.base import BaseEstimator
77

88
import ray
9+
import pickle5 as pickle
910
import codeflare.pipelines.Exceptions as pe
1011

1112
class Xy:
@@ -103,6 +104,9 @@ def __init__(self, node_name, node_input_type: NodeInputType, node_firing_type:
103104
def __str__(self):
104105
return self.__node_name__
105106

107+
def get_node_name(self):
    """
    Returns the name of this node, the same value that ``__str__`` renders.

    :return: The value held in ``__node_name__``
    """
    node_name = self.__node_name__
    return node_name
109+
106110
def get_node_input_type(self):
107111
return self.__node_input_type__
108112

@@ -379,6 +383,52 @@ def get_terminal_nodes(self):
379383
terminal_nodes.append(node)
380384
return terminal_nodes
381385

386+
def save(self, filehandle):
    """
    Pickles the graph of this pipeline to the given file handle.

    Nodes are stored keyed by their name and edges as ``(from_name, to_name)``
    tuples, both wrapped in a ``_SavedPipeline`` object.

    :param filehandle: Writable binary file object that receives the pickled data
    :return: None
    """
    nodes_by_name = {}
    named_edges = []

    for node in self.__pre_graph__:
        # NOTE(review): nodes sharing a name overwrite each other here -- confirm names are unique
        nodes_by_name[node.get_node_name()] = node
        for pre_edge in self.get_pre_edges(node):
            source = pre_edge.get_from_node()
            if source is None:
                # Edge without an originating node, nothing to record
                continue
            # Since we are iterating on pre_edges, to_node cannot be None
            target = pre_edge.get_to_node()
            named_edges.append((source.get_node_name(), target.get_node_name()))

    pickle.dump(_SavedPipeline(nodes_by_name, named_edges), filehandle)
402+
403+
@staticmethod
def load(filehandle):
    """
    Rebuilds a pipeline from pickled ``_SavedPipeline`` data.

    :param filehandle: Readable binary file object holding the pickled data
    :return: A new ``Pipeline`` containing every saved edge
    :raises pe.PipelineException: If the unpickled object is not a ``_SavedPipeline``
    """
    # NOTE(review): unpickling can execute arbitrary code, so only load files from trusted sources
    restored = pickle.load(filehandle)
    if not isinstance(restored, _SavedPipeline):
        raise pe.PipelineException("Filehandle is not a saved pipeline instance")

    node_lookup = restored.get_nodes()
    named_edges = restored.get_edges()

    # NOTE(review): only nodes referenced by an edge are re-added below; presumably a
    # saved node without any edge is dropped -- confirm whether that is intended
    rebuilt = Pipeline()
    for source_name, target_name in named_edges:
        rebuilt.add_edge(node_lookup[source_name], node_lookup[target_name])
    return rebuilt
419+
420+
421+
class _SavedPipeline:
422+
def __init__(self, nodes, edges):
423+
self.__nodes__ = nodes
424+
self.__edges__ = edges
425+
426+
def get_nodes(self):
427+
return self.__nodes__
428+
429+
def get_edges(self):
430+
return self.__edges__
431+
382432

383433
class PipelineOutput:
384434
"""

codeflare/pipelines/Runtime.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from enum import Enum
99

1010
from queue import SimpleQueue
11+
import pickle5 as pickle
1112

1213

1314
class ExecutionType(Enum):
@@ -265,3 +266,8 @@ def cross_validate(cross_validator: BaseCrossValidator, pipeline: dm.Pipeline, p
265266
result_scores.append(out_x)
266267

267268
return result_scores
269+
270+
271+
def save(pipeline_output: dm.PipelineOutput, xy_ref: dm.XYRef, filehandle):
    """
    Runs ``select_pipeline`` on the given output and reference, then saves the
    resulting pipeline to the file handle.

    :param pipeline_output: Pipeline output passed to ``select_pipeline``
    :param xy_ref: XYRef passed to ``select_pipeline``
    :param filehandle: File object handed to the selected pipeline's ``save``
    :return: None
    """
    select_pipeline(pipeline_output, xy_ref).save(filehandle)

codeflare/pipelines/tests/__init__.py

Whitespace-only changes.
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
import pytest
2+
3+
import codeflare.pipelines.Datamodel as dm
4+
import codeflare.pipelines.Runtime as rt
5+
6+
import numpy as np
7+
from sklearn.preprocessing import FunctionTransformer
8+
from sklearn.preprocessing import MinMaxScaler
9+
import os
10+
11+
12+
class FeatureUnion(dm.AndTransform):
    """
    AND transform used by the tests: concatenates the X of all incoming Xy objects.
    """

    def __init__(self):
        pass

    def transform(self, xy_list):
        """
        :param xy_list: Iterable of objects exposing ``get_x()``
        :return: ``dm.Xy`` holding the X arrays concatenated along axis 0, with y set to None
        """
        X_concat = np.concatenate([xy.get_x() for xy in xy_list], axis=0)
        return dm.Xy(X_concat, None)
25+
26+
27+
def test_save_load():
    """
    A simple save load test for a pipeline graph: saves a three node graph, loads it
    back and checks that the AND node still has both incoming edges.
    :return:
    """
    import tempfile  # local import, only this test needs it

    pipeline = dm.Pipeline()
    minmax_scaler = MinMaxScaler()

    node_a = dm.EstimatorNode('a', minmax_scaler)
    node_b = dm.EstimatorNode('b', minmax_scaler)
    node_c = dm.AndNode('c', FeatureUnion())

    pipeline.add_edge(node_a, node_c)
    pipeline.add_edge(node_b, node_c)

    # The temporary directory keeps the working directory clean and is removed
    # even when the load or an assertion fails
    with tempfile.TemporaryDirectory() as tmp_dir:
        fname = os.path.join(tmp_dir, 'save_pipeline.cfp')

        with open(fname, 'wb') as fh:
            pipeline.save(fh)

        # Both handles are closed before the directory is cleaned up
        with open(fname, 'rb') as r_fh:
            saved_pipeline = dm.Pipeline.load(r_fh)

    pre_edges = saved_pipeline.get_pre_edges(node_c)
    assert len(pre_edges) == 2
53+
54+
55+
def test_runtime_save_load():
    """
    Tests for selecting a pipeline and save/load it, we also test the predict to ensure state is
    captured accurately

    NOTE(review): this test has no body yet, so it passes without checking anything.
    :return:
    """
61+

requirements.txt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,7 @@ ray~=1.3.0
22
setuptools~=52.0.0
33
sklearn~=0.0
44
scikit-learn~=0.24.1
5-
pandas~=1.2.4
5+
pandas~=1.2.4
6+
pytest~=6.2.4
7+
numpy~=1.18.5
8+
pickle5~=0.0.11

0 commit comments

Comments
 (0)