Skip to content

Commit cf94b0a

Browse files
committed
updated AndEstimator syntax and implementation
1 parent cf631a1 commit cf94b0a

File tree

3 files changed

+22
-20
lines changed

3 files changed

+22
-20
lines changed

codeflare/pipelines/tests/test_Datamodel.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,32 +7,29 @@
77
from sklearn.pipeline import Pipeline
88
from sklearn.preprocessing import StandardScaler, MinMaxScaler
99
from sklearn.tree import DecisionTreeClassifier
10+
import sklearn.base as base
1011
import codeflare.pipelines.Datamodel as dm
1112
import codeflare.pipelines.Runtime as rt
1213
from codeflare.pipelines.Datamodel import Xy
1314
from codeflare.pipelines.Runtime import ExecutionType
1415

15-
1616
class FeatureUnion(dm.AndEstimator):
1717
def __init__(self):
1818
pass
19-
20-
def fit_transform(self, xy_list: list):
21-
return self.transform(xy_list)
22-
2319
def get_estimator_type(self):
2420
return 'transform'
25-
21+
def clone(self):
22+
return base.clone(self)
23+
def fit_transform(self, xy_list):
24+
return self.transform(xy_list)
2625
def transform(self, xy_list):
2726
X_list = []
28-
y_list = []
29-
27+
y_vec = None
3028
for xy in xy_list:
3129
X_list.append(xy.get_x())
32-
X_concat = np.concatenate(X_list, axis=0)
33-
34-
return Xy(X_concat, None)
35-
30+
y_vec = xy.get_y()
31+
X_concat = np.concatenate(X_list, axis=1)
32+
return Xy(X_concat, y_vec)
3633

3734
class MultibranchTestCase(unittest.TestCase):
3835

codeflare/pipelines/tests/test_and.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import ray
33
import pandas as pd
44
import numpy as np
5+
import sklearn.base as base
56
from sklearn.preprocessing import StandardScaler, MinMaxScaler, MaxAbsScaler, RobustScaler
67
import codeflare.pipelines.Datamodel as dm
78
import codeflare.pipelines.Runtime as rt
@@ -12,19 +13,19 @@
1213
class FeatureUnion(dm.AndEstimator):
1314
def __init__(self):
1415
pass
15-
1616
def get_estimator_type(self):
1717
return 'transform'
18-
18+
def clone(self):
19+
return base.clone(self)
20+
def fit_transform(self, xy_list):
21+
return self.transform(xy_list)
1922
def transform(self, xy_list):
2023
X_list = []
2124
y_vec = None
22-
2325
for xy in xy_list:
2426
X_list.append(xy.get_x())
2527
y_vec = xy.get_y()
2628
X_concat = np.concatenate(X_list, axis=1)
27-
2829
return Xy(X_concat, y_vec)
2930

3031
def test_two_tier_and():

codeflare/pipelines/tests/test_multibranch.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from sklearn.preprocessing import StandardScaler, MinMaxScaler
99
from sklearn.tree import DecisionTreeClassifier
1010
from sklearn.linear_model import LogisticRegression
11+
import sklearn.base as base
1112
import codeflare.pipelines.Datamodel as dm
1213
import codeflare.pipelines.Runtime as rt
1314
from codeflare.pipelines.Datamodel import Xy
@@ -17,17 +18,20 @@
1718
class FeatureUnion(dm.AndEstimator):
1819
def __init__(self):
1920
pass
20-
21+
def get_estimator_type(self):
22+
return 'transform'
23+
def clone(self):
24+
return base.clone(self)
25+
def fit_transform(self, xy_list):
26+
return self.transform(xy_list)
2127
def transform(self, xy_list):
2228
X_list = []
2329
y_vec = None
24-
2530
for xy in xy_list:
2631
X_list.append(xy.get_x())
2732
y_vec = xy.get_y()
2833
X_concat = np.concatenate(X_list, axis=1)
29-
30-
return Xy(X_concat, y_vec.values.ravel())
34+
return Xy(X_concat, y_vec)
3135

3236
def test_multibranch_1():
3337

0 commit comments

Comments
 (0)