Skip to content

Commit d6a737e

Browse files
committed
test codeflare pipeline with and node
1 parent a5ade55 commit d6a737e

File tree

1 file changed

+61
-1
lines changed

1 file changed

+61
-1
lines changed

codeflare/tests/test_and.py

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,66 @@
11
import pytest
22
import ray
3+
import pandas as pd
4+
import numpy as np
5+
from sklearn.preprocessing import StandardScaler, MinMaxScaler
6+
import codeflare.pipelines.Datamodel as dm
7+
import codeflare.pipelines.Runtime as rt
38
from codeflare.pipelines.Datamodel import Xy
49
from codeflare.pipelines.Datamodel import XYRef
5-
import codeflare.pipelines.Datamodel as dm
10+
from codeflare.pipelines.Runtime import ExecutionType
11+
12+
class FeatureUnion(dm.AndTransform):
13+
def __init__(self):
14+
pass
15+
16+
def transform(self, xy_list):
17+
X_list = []
18+
y_list = []
19+
20+
for xy in xy_list:
21+
X_list.append(xy.get_x())
22+
X_concat = np.concatenate(X_list, axis=0)
23+
24+
return Xy(X_concat, None)
25+
26+
def test_and():
27+
28+
ray.shutdown()
29+
ray.init()
30+
31+
## prepare the data
32+
X = np.random.randint(0,100,size=(10000, 4))
33+
34+
X_ref = ray.put(X)
35+
y_ref = ray.put(None)
36+
37+
Xy_ref = XYRef(X_ref, y_ref)
38+
Xy_ref_ptr = ray.put(Xy_ref)
39+
Xy_ref_ptrs = [Xy_ref_ptr]
40+
41+
## initialize codeflare pipeline by first creating the nodes
42+
pipeline = dm.Pipeline()
43+
node_a = dm.EstimatorNode('a', MinMaxScaler())
44+
node_b = dm.EstimatorNode('b', StandardScaler())
45+
node_c = dm.AndNode('c', FeatureUnion())
46+
47+
## codeflare nodes are then connected by edges
48+
pipeline.add_edge(node_a, node_b)
49+
pipeline.add_edge(node_a, node_c)
50+
51+
in_args={node_a: Xy_ref_ptrs, node_b: Xy_ref_ptrs}
52+
## execute the codeflare pipeline
53+
out_args = rt.execute_pipeline(pipeline, ExecutionType.FIT, in_args)
54+
55+
## retrieve node c
56+
out_Xyrefs = ray.get(out_args[node_c])
57+
for out_xyref in out_Xyrefs:
58+
x = ray.get(out_xyref.get_Xref())
59+
and_func = ray.get(out_xyref.get_currnoderef()).get_and_func()
60+
print(x)
61+
62+
ray.shutdown()
63+
664

65+
if __name__ == "__main__":
66+
sys.exit(pytest.main(["-v", __file__]))

0 commit comments

Comments
 (0)