Skip to content

Commit 2fd56c5

Browse files
committed
updated api call interfaces
1 parent 908d661 commit 2fd56c5

File tree

3 files changed

+27
-58
lines changed

3 files changed

+27
-58
lines changed

codeflare/tests/test_and.py

Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,7 @@ def test_and():
3030

3131
## prepare the data
3232
X = np.random.randint(0,100,size=(10000, 4))
33-
34-
X_ref = ray.put(X)
35-
y_ref = ray.put(None)
36-
37-
Xy_ref = XYRef(X_ref, y_ref)
38-
Xy_ref_ptr = ray.put(Xy_ref)
39-
Xy_ref_ptrs = [Xy_ref_ptr]
33+
y = np.random.randint(0,2,size=(10000, 1))
4034

4135
## initialize codeflare pipeline by first creating the nodes
4236
pipeline = dm.Pipeline()
@@ -48,19 +42,17 @@ def test_and():
4842
pipeline.add_edge(node_a, node_c)
4943
pipeline.add_edge(node_b, node_c)
5044

51-
in_args={node_a: Xy_ref_ptrs, node_b: Xy_ref_ptrs}
45+
pipeline_input = dm.PipelineInput()
46+
xy = dm.Xy(X,y)
47+
pipeline_input.add_xy_arg(node_a, xy)
48+
pipeline_input.add_xy_arg(node_b, xy)
49+
5250
## execute the codeflare pipeline
53-
out_args = rt.execute_pipeline(pipeline, ExecutionType.FIT, in_args)
51+
pipeline_output = rt.execute_pipeline(pipeline, ExecutionType.FIT, pipeline_input)
5452

5553
## retrieve node c
56-
out_Xyrefs = ray.get(out_args[node_c])
57-
assert out_Xyrefs
58-
59-
for out_xyref in out_Xyrefs:
60-
x = ray.get(out_xyref.get_Xref())
61-
and_func = ray.get(out_xyref.get_currnoderef()).get_and_func()
62-
assert x.any()
63-
print(x)
54+
node_c_output = pipeline_output.get_xyrefs(node_c)
55+
assert node_c_output
6456

6557
ray.shutdown()
6658

codeflare/tests/test_multibranch.py

Lines changed: 8 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -48,13 +48,6 @@ def test_multibranch():
4848

4949
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
5050

51-
X_ref = ray.put(X_train)
52-
y_ref = ray.put(y_train)
53-
54-
Xy_ref = XYRef(X_ref, y_ref)
55-
Xy_ref_ptr = ray.put(Xy_ref)
56-
Xy_ref_ptrs = [Xy_ref_ptr]
57-
5851
## create two decision tree classifiers with different depth limit
5952
c_a = DecisionTreeClassifier(max_depth=3)
6053
c_b = DecisionTreeClassifier(max_depth=5)
@@ -78,23 +71,17 @@ def test_multibranch():
7871
pipeline.add_edge(node_d, node_f)
7972
pipeline.add_edge(node_e, node_f)
8073

81-
in_args={node_a: Xy_ref_ptrs}
74+
pipeline_input = dm.PipelineInput()
75+
xy = dm.Xy(X_train, y_train)
76+
pipeline_input.add_xy_arg(node_a, xy)
77+
8278
## execute the codeflare pipeline
83-
out_args = rt.execute_pipeline(pipeline, ExecutionType.FIT, in_args)
84-
assert out_args
79+
pipeline_output = rt.execute_pipeline(pipeline, ExecutionType.FIT, pipeline_input)
80+
assert pipeline_output
8581

8682
## retrieve node b
87-
node_b_out_args = ray.get(out_args[node_b])
88-
b_out_xyref = node_b_out_args[0]
89-
ray.get(b_out_xyref.get_Xref())
90-
b_out_node = ray.get(b_out_xyref.get_currnoderef())
91-
sct_b = b_out_node.get_estimator()
92-
assert sct_b
93-
print(sct_b.feature_importances_)
94-
95-
## retrieve node f
96-
out_Xyrefs_f = ray.get(out_args[node_f])
97-
assert out_Xyrefs_f
83+
node_b_output = pipeline_output.get_xyrefs(node_b)
84+
assert node_b_output
9885

9986
ray.shutdown()
10087

codeflare/tests/test_or.py

Lines changed: 10 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,7 @@ def test_or():
3333
])
3434

3535
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
36-
37-
X_ref = ray.put(X_train)
38-
y_ref = ray.put(y_train)
39-
40-
Xy_ref = XYRef(X_ref, y_ref)
41-
Xy_ref_ptr = ray.put(Xy_ref)
42-
Xy_ref_ptrs = [Xy_ref_ptr]
43-
36+
4437
## create two decision tree classifiers with different depth limit
4538
c_a = DecisionTreeClassifier(max_depth=3)
4639
c_b = DecisionTreeClassifier(max_depth=5)
@@ -55,19 +48,16 @@ def test_or():
5548
pipeline.add_edge(node_a, node_b)
5649
pipeline.add_edge(node_a, node_c)
5750

58-
in_args={node_a: Xy_ref_ptrs}
59-
## execute the codeflare pipeline
60-
out_args = rt.execute_pipeline(pipeline, ExecutionType.FIT, in_args)
61-
assert out_args
51+
pipeline_input = dm.PipelineInput()
52+
xy = dm.Xy(X_train, y_train)
53+
pipeline_input.add_xy_arg(node_a, xy)
54+
55+
pipeline_output = rt.execute_pipeline(pipeline, ExecutionType.FIT, pipeline_input)
56+
57+
node_b_output = pipeline_output.get_xyrefs(node_b)
58+
node_c_output = pipeline_output.get_xyrefs(node_c)
6259

63-
## retrieve node b
64-
node_b_out_args = ray.get(out_args[node_b])
65-
b_out_xyref = node_b_out_args[0]
66-
ray.get(b_out_xyref.get_Xref())
67-
b_out_node = ray.get(b_out_xyref.get_currnoderef())
68-
sct_b = b_out_node.get_estimator()
69-
assert sct_b
70-
print(sct_b.feature_importances_)
60+
assert node_b_output
7161

7262
ray.shutdown()
7363

0 commit comments

Comments
 (0)