Skip to content

Commit b625129

Browse files
committed
catch up to develop
1 parent aaeaf50 commit b625129

File tree

2 files changed

+9
-3
lines changed

2 files changed

+9
-3
lines changed

codeflare/pipelines/Datamodel.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -837,7 +837,9 @@ def get_post_edges(self, node: Node):
837837
:return: Outgoing edges for the node
838838
"""
839839
post_edges = []
840-
post_nodes = self.__post_graph__[node]
840+
post_nodes = []
841+
if node in self.__post_graph__.keys():
842+
post_nodes = self.__post_graph__[node]
841843
# Empty post
842844
if not post_nodes:
843845
post_edges.append(Edge(node, None))

codeflare/pipelines/Runtime.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,8 @@ class ExecutionType(Enum):
5858
"""
5959
FIT = 0,
6060
PREDICT = 1,
61-
SCORE = 2
61+
SCORE = 2,
62+
TRANSFORM = 3
6263

6364

6465
@ray.remote
@@ -140,7 +141,10 @@ def execute_or_node_remote(node: dm.EstimatorNode, mode: ExecutionType, xy_ref:
140141
res_Xref = ray.put(estimator.transform(X))
141142
result = dm.XYRef(res_Xref, xy_ref.get_yref(), prev_node_ptr, prev_node_ptr, [xy_ref])
142143
return result
143-
144+
elif mode == ExecutionType.TRANSFORM:
145+
res_Xref = ray.put(estimator.fit_transform(X))
146+
result = dm.XYRef(res_Xref, xy_ref.get_yref(), prev_node_ptr, prev_node_ptr, [xy_ref])
147+
return result
144148

145149
def execute_or_node(node, pre_edges, edge_args, post_edges, mode: ExecutionType, is_outputNode):
146150
"""

0 commit comments

Comments
 (0)