Skip to content

Commit aaeaf50

Browse files
committed
catch up to develop
1 parent 53f5542 commit aaeaf50

File tree

1 file changed

+46
-0
lines changed

1 file changed

+46
-0
lines changed
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import pytest
2+
import ray
3+
import pandas as pd
4+
import numpy as np
5+
import sklearn.base as base
6+
from sklearn.preprocessing import MinMaxScaler
7+
import codeflare.pipelines.Datamodel as dm
8+
import codeflare.pipelines.Runtime as rt
9+
from codeflare.pipelines.Datamodel import Xy
10+
from codeflare.pipelines.Datamodel import XYRef
11+
from codeflare.pipelines.Runtime import ExecutionType
12+
13+
def test_singleton():
14+
15+
ray.shutdown()
16+
ray.init()
17+
18+
## prepare the data
19+
X = np.random.randint(0,100,size=(10000, 4))
20+
y = np.random.randint(0,2,size=(10000, 1))
21+
22+
## initialize codeflare pipeline by first creating the nodes
23+
pipeline = dm.Pipeline()
24+
node_a = dm.EstimatorNode('a', MinMaxScaler())
25+
pipeline.add_node(node_a)
26+
27+
pipeline_input = dm.PipelineInput()
28+
xy = dm.Xy(X,y)
29+
pipeline_input.add_xy_arg(node_a, xy)
30+
31+
## execute the codeflare pipeline
32+
pipeline_output = rt.execute_pipeline(pipeline, ExecutionType.TRANSFORM, pipeline_input)
33+
34+
## retrieve node e
35+
node_a_output = pipeline_output.get_xyrefs(node_a)
36+
Xout = ray.get(node_a_output[0].get_Xref())
37+
yout = ray.get(node_a_output[0].get_yref())
38+
39+
assert Xout.shape[0] == 10000
40+
assert yout.shape[0] == 10000
41+
42+
ray.shutdown()
43+
44+
if __name__ == "__main__":
45+
sys.exit(pytest.main(["-v", __file__]))
46+

0 commit comments

Comments
 (0)