Skip to content

Commit a0195ea

Browse files
committed
test codeflare pipeline with or node
1 parent 9af9781 commit a0195ea

File tree

1 file changed

+43
-23
lines changed

1 file changed

+43
-23
lines changed

codeflare/tests/test_or.py

Lines changed: 43 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,54 +1,74 @@
11
import pytest
22
import ray
33
import pandas as pd
4+
import numpy as np
5+
from sklearn.compose import ColumnTransformer
46
from sklearn.model_selection import train_test_split
57
from sklearn.pipeline import Pipeline
6-
from sklearn.impute import SimpleImputer
7-
from sklearn.preprocessing import StandardScaler, OneHotEncoder
8-
from hercules.Datamodel import Xy
9-
from hercules.Datamodel import XYRef
10-
import hercules.Datamodel as dm
11-
import hercules.RuntimeNew as rt
12-
from hercules.RuntimeNew import ExecutionType
8+
from sklearn.preprocessing import StandardScaler
9+
from sklearn.tree import DecisionTreeClassifier
10+
import codeflare.pipelines.Datamodel as dm
11+
import codeflare.pipelines.Runtime as rt
12+
from codeflare.pipelines.Datamodel import Xy
13+
from codeflare.pipelines.Datamodel import XYRef
14+
from codeflare.pipelines.Runtime import ExecutionType
1315

1416
def test_or():
1517

18+
ray.shutdown()
1619
ray.init()
17-
18-
train = pd.read_csv('../resources/data/train_ctrUa4K.csv')
19-
test = pd.read_csv('../resources/data/test_lAUu6dG.csv')
2020

21-
X = train.drop('Loan_Status', axis=1)
22-
y = train['Loan_Status']
21+
## prepare the data
22+
X = pd.DataFrame(np.random.randint(0,100,size=(10000, 4)), columns=list('ABCD'))
23+
y = pd.DataFrame(np.random.randint(0,2,size=(10000, 1)), columns=['Label'])
24+
25+
numeric_features = X.select_dtypes(include=['int64']).columns
26+
numeric_transformer = Pipeline(steps=[
27+
('scaler', StandardScaler())])
28+
29+
## set up preprocessor as StandardScaler
30+
preprocessor = ColumnTransformer(
31+
transformers=[
32+
('num', numeric_transformer, numeric_features),
33+
])
2334

2435
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
2536

2637
X_ref = ray.put(X_train)
2738
y_ref = ray.put(y_train)
2839

2940
Xy_ref = XYRef(X_ref, y_ref)
30-
Xy_ref_list = [Xy_ref]
41+
Xy_ref_ptr = ray.put(Xy_ref)
42+
Xy_ref_ptrs = [Xy_ref_ptr]
43+
44+
## create two decision tree classifiers with different depth limit
45+
c_a = DecisionTreeClassifier(max_depth=3)
46+
c_b = DecisionTreeClassifier(max_depth=5)
3147

48+
## initialize codeflare pipeline by first creating the nodes
3249
pipeline = dm.Pipeline()
33-
node_a = dm.OrNode('preprocess', preprocessor)
34-
node_b = dm.OrNode('c_a', c_a)
35-
node_c = dm.OrNode('c_b', c_b)
36-
50+
node_a = dm.EstimatorNode('preprocess', preprocessor)
51+
node_b = dm.EstimatorNode('c_a', c_a)
52+
node_c = dm.EstimatorNode('c_b', c_b)
53+
54+
## codeflare nodes are then connected by edges
3755
pipeline.add_edge(node_a, node_b)
3856
pipeline.add_edge(node_a, node_c)
3957

40-
in_args={node_a: Xy_ref_list}
58+
in_args={node_a: Xy_ref_ptrs}
59+
## execute the codeflare pipeline
4160
out_args = rt.execute_pipeline(pipeline, ExecutionType.FIT, in_args)
4261

62+
## retrieve node b
4363
node_b_out_args = ray.get(out_args[node_b])
44-
node_c_out_args = ray.get(out_args[node_c])
45-
4664
b_out_xyref = node_b_out_args[0]
47-
4865
ray.get(b_out_xyref.get_Xref())
66+
b_out_node = ray.get(b_out_xyref.get_currnoderef())
67+
sct_b = b_out_node.get_estimator()
68+
print(sct_b.feature_importances_)
69+
70+
ray.shutdown()
4971

5072

5173
if __name__ == "__main__":
5274
raise SystemExit(pytest.main(["-v", __file__]))  # same effect as sys.exit(); avoids NameError because this file never imports sys
53-
54-

0 commit comments

Comments
 (0)