From 41666c4e4a0cf02fa313afa0938d16222a654faa Mon Sep 17 00:00:00 2001 From: tfx-team Date: Tue, 25 Jan 2022 22:36:17 -0800 Subject: [PATCH] Optionally reuse artifacts that do not affect consistent execution of downstream nodes in partial run. PiperOrigin-RevId: 424263054 --- .../experimental/core/pipeline_ops_test.py | 16 ++- .../portable/partial_run_utils.py | 87 +++++++++---- .../portable/partial_run_utils_test.py | 116 +++++++++++++----- tfx/proto/orchestration/pipeline.proto | 17 ++- 4 files changed, 179 insertions(+), 57 deletions(-) diff --git a/tfx/orchestration/experimental/core/pipeline_ops_test.py b/tfx/orchestration/experimental/core/pipeline_ops_test.py index 52af4e07de..208acd3d18 100644 --- a/tfx/orchestration/experimental/core/pipeline_ops_test.py +++ b/tfx/orchestration/experimental/core/pipeline_ops_test.py @@ -128,8 +128,12 @@ def test_resume_pipeline(self, mock_snapshot): with self._mlmd_connection as m: pipeline = _test_pipeline('test_pipeline', pipeline_pb2.Pipeline.SYNC) pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline) - pipeline.nodes.add().pipeline_node.node_info.id = 'ExampleGen' - pipeline.nodes.add().pipeline_node.node_info.id = 'Trainer' + node_example_gen = pipeline.nodes.add().pipeline_node + node_example_gen.node_info.id = 'ExampleGen' + node_example_gen.downstream_nodes.extend(['Trainer']) + node_trainer = pipeline.nodes.add().pipeline_node + node_trainer.node_info.id = 'Trainer' + node_trainer.upstream_nodes.extend(['ExampleGen']) # Error if attempt to resume the pipeline when there is no previous run. 
with self.assertRaises(status_lib.StatusNotOkError) as exception_context: @@ -167,9 +171,11 @@ def test_resume_pipeline(self, mock_snapshot): expected_pipeline.runtime_spec.snapshot_settings.latest_pipeline_run_strategy.SetInParent( ) expected_pipeline.nodes[ - 0].pipeline_node.execution_options.skip.reuse_artifacts = True + 0].pipeline_node.execution_options.skip.reuse_artifacts_mode = pipeline_pb2.NodeExecutionOptions.Skip.REQUIRED expected_pipeline.nodes[ 1].pipeline_node.execution_options.run.perform_snapshot = True + expected_pipeline.nodes[ + 1].pipeline_node.execution_options.run.depends_on_snapshot = True with pipeline_ops.resume_pipeline(m, pipeline) as pipeline_state_run2: self.assertEqual(expected_pipeline, pipeline_state_run2.pipeline) pipeline_state_run2.is_active() @@ -207,7 +213,7 @@ def test_initiate_pipeline_start_with_partial_run(self, mock_snapshot): expected_pipeline.runtime_spec.snapshot_settings.latest_pipeline_run_strategy.SetInParent( ) expected_pipeline.nodes[ - 0].pipeline_node.execution_options.skip.reuse_artifacts = True + 0].pipeline_node.execution_options.skip.reuse_artifacts_mode = pipeline_pb2.NodeExecutionOptions.Skip.REQUIRED expected_pipeline.nodes[ 1].pipeline_node.execution_options.run.perform_snapshot = True expected_pipeline.nodes[ @@ -247,7 +253,7 @@ def test_initiate_pipeline_start_with_partial_run_default_to_nodes( expected_pipeline.runtime_spec.snapshot_settings.latest_pipeline_run_strategy.SetInParent( ) expected_pipeline.nodes[ - 0].pipeline_node.execution_options.skip.reuse_artifacts = True + 0].pipeline_node.execution_options.skip.reuse_artifacts_mode = pipeline_pb2.NodeExecutionOptions.Skip.REQUIRED expected_pipeline.nodes[ 1].pipeline_node.execution_options.run.perform_snapshot = True expected_pipeline.nodes[ diff --git a/tfx/orchestration/portable/partial_run_utils.py b/tfx/orchestration/portable/partial_run_utils.py index a01d9af89b..34a0d5950a 100644 --- a/tfx/orchestration/portable/partial_run_utils.py +++ 
b/tfx/orchestration/portable/partial_run_utils.py @@ -15,7 +15,7 @@ import collections import enum -from typing import Collection, List, Mapping, Optional, Set +from typing import Collection, List, Mapping, Optional, Set, Tuple from absl import logging from tfx.dsl.compiler import compiler_utils @@ -31,6 +31,8 @@ _default_snapshot_settings = pipeline_pb2.SnapshotSettings() _default_snapshot_settings.latest_pipeline_run_strategy.SetInParent() +_REUSE_ARTIFACT_REQUIRED = pipeline_pb2.NodeExecutionOptions.Skip.REQUIRED +_REUSE_ARTIFACT_OPTIONAL = pipeline_pb2.NodeExecutionOptions.Skip.OPTIONAL def mark_pipeline( @@ -94,12 +96,13 @@ def mark_pipeline( nodes_to_run = _compute_nodes_to_run(node_map, from_node_ids, to_node_ids, skip_node_ids) - nodes_to_reuse = _compute_nodes_to_reuse(node_map, nodes_to_run) + nodes_required_to_reuse, nodes_to_reuse = _compute_nodes_to_reuse( + node_map, nodes_to_run) nodes_requiring_snapshot = _compute_nodes_requiring_snapshot( node_map, nodes_to_run, nodes_to_reuse) snapshot_node = _pick_snapshot_node(node_map, nodes_to_run, nodes_to_reuse) - _mark_nodes(node_map, nodes_to_run, nodes_to_reuse, nodes_requiring_snapshot, - snapshot_node) + _mark_nodes(node_map, nodes_to_run, nodes_required_to_reuse, nodes_to_reuse, + nodes_requiring_snapshot, snapshot_node) pipeline.runtime_spec.snapshot_settings.CopyFrom(snapshot_settings) return pipeline @@ -122,7 +125,7 @@ def snapshot(mlmd_handle: metadata.Metadata, """ # Avoid unnecessary snapshotting step if no node needs to reuse any artifacts. 
if not any( - _should_reuse_artifact(node.pipeline_node.execution_options) + _should_attempt_to_reuse_artifact(node.pipeline_node.execution_options) for node in pipeline.nodes): return @@ -155,8 +158,8 @@ def _pick_snapshot_node(node_map: Mapping[str, pipeline_pb2.PipelineNode], def _mark_nodes(node_map: Mapping[str, pipeline_pb2.PipelineNode], - nodes_to_run: Set[str], nodes_to_reuse: Set[str], - nodes_requiring_snapshot: Set[str], + nodes_to_run: Set[str], nodes_required_to_reuse: Set[str], + nodes_to_reuse: Set[str], nodes_requiring_snapshot: Set[str], snapshot_node: Optional[str]): """Mark nodes.""" for node_id, node in node_map.items(): # assumes topological order @@ -164,8 +167,15 @@ def _mark_nodes(node_map: Mapping[str, pipeline_pb2.PipelineNode], node.execution_options.run.perform_snapshot = (node_id == snapshot_node) node.execution_options.run.depends_on_snapshot = ( node_id in nodes_requiring_snapshot) + elif node_id in nodes_required_to_reuse: + node.execution_options.skip.reuse_artifacts_mode = ( + _REUSE_ARTIFACT_REQUIRED) + elif node_id in nodes_to_reuse: + node.execution_options.skip.reuse_artifacts_mode = ( + _REUSE_ARTIFACT_OPTIONAL) else: - node.execution_options.skip.reuse_artifacts = (node_id in nodes_to_reuse) + node.execution_options.skip.reuse_artifacts_mode = ( + pipeline_pb2.NodeExecutionOptions.Skip.NEVER) class _Direction(enum.Enum): @@ -330,19 +340,30 @@ def _compute_nodes_to_run( def _compute_nodes_to_reuse( node_map: Mapping[str, pipeline_pb2.PipelineNode], - nodes_to_run: Set[str], -) -> Set[str]: - """Returns the set of node ids whose output artifacts are to be reused.""" + nodes_to_run: Set[str]) -> Tuple[Set[str], Set[str]]: + """Returns the set of node ids whose output artifacts are to be reused. + + Only upstream nodes of nodes_to_run are required to be reused to reflect + correct lineage. + + Args: + node_map: Mapping of node_id to nodes. + nodes_to_run: The set of nodes to run. 
+ + Returns: + Set of node ids required to be reused. + """ exclusion_set = _traverse( node_map, _Direction.DOWNSTREAM, start_nodes=nodes_to_run) - return set(node_map.keys()) - exclusion_set + nodes_required_to_reuse = _traverse( + node_map, _Direction.UPSTREAM, start_nodes=nodes_to_run) - exclusion_set + nodes_to_reuse = set(node_map.keys()) - exclusion_set + return nodes_required_to_reuse, nodes_to_reuse -def _compute_nodes_requiring_snapshot( - node_map: Mapping[str, pipeline_pb2.PipelineNode], - nodes_to_run: Set[str], - nodes_to_reuse: Set[str], -) -> Set[str]: +def _compute_nodes_requiring_snapshot(node_map: Mapping[ + str, pipeline_pb2.PipelineNode], nodes_to_run: Set[str], + nodes_to_reuse: Set[str]) -> Set[str]: """Returns the set of nodes to run that depend on a node to reuse.""" result = set() for node_id, node in node_map.items(): @@ -401,10 +422,19 @@ def _get_validated_new_run_id(pipeline: pipeline_pb2.Pipeline, return str(inferred_new_run_id or new_run_id) -def _should_reuse_artifact( +def _should_attempt_to_reuse_artifact( + execution_options: pipeline_pb2.NodeExecutionOptions): + return execution_options.HasField('skip') and ( + execution_options.skip.reuse_artifacts or + execution_options.skip.reuse_artifacts_mode == _REUSE_ARTIFACT_OPTIONAL or + execution_options.skip.reuse_artifacts_mode == _REUSE_ARTIFACT_REQUIRED) + + +def _reuse_artifact_required( execution_options: pipeline_pb2.NodeExecutionOptions): - return (execution_options.HasField('skip') and - execution_options.skip.reuse_artifacts) + return execution_options.HasField('skip') and ( + execution_options.skip.reuse_artifacts or + execution_options.skip.reuse_artifacts_mode == _REUSE_ARTIFACT_REQUIRED) def _reuse_pipeline_run_artifacts(metadata_handler: metadata.Metadata, @@ -449,9 +479,22 @@ def _reuse_pipeline_run_artifacts(metadata_handler: metadata.Metadata, 'base_run_id not provided. 
' 'Default to latest pipeline run: %s', base_run_id) for node in marked_pipeline.nodes: - if _should_reuse_artifact(node.pipeline_node.execution_options): + if _should_attempt_to_reuse_artifact(node.pipeline_node.execution_options): node_id = node.pipeline_node.node_info.id - artifact_recycler.reuse_node_outputs(node_id, base_run_id) + try: + artifact_recycler.reuse_node_outputs(node_id, base_run_id) + except Exception as err: # pylint: disable=broad-except + if _reuse_artifact_required(node.pipeline_node.execution_options): + # Raise error only if failed to reuse artifacts required. + raise + err_str = str(err) + if 'No previous successful execution' in err_str or 'node context' in err_str: + # This is mostly due to no previous execution of the node, so the + # error is safe to suppress. + logging.info(err_str) + else: + logging.warning('Failed to reuse artifacts for node %s. Due to %s', + node_id, err_str) artifact_recycler.put_parent_context(base_run_id) diff --git a/tfx/orchestration/portable/partial_run_utils_test.py b/tfx/orchestration/portable/partial_run_utils_test.py index 83011e4e8b..0380ec0588 100644 --- a/tfx/orchestration/portable/partial_run_utils_test.py +++ b/tfx/orchestration/portable/partial_run_utils_test.py @@ -80,12 +80,15 @@ def _to_input_channel( class MarkPipelineFnTest(parameterized.TestCase, test_case_utils.TfxTest): - def _checkNodeExecutionOptions(self, pipeline: pipeline_pb2.Pipeline, - snapshot_node: Optional[str], - nodes_to_run: Set[str], - nodes_requiring_snapshot: Set[str], - nodes_to_skip: Set[str], - nodes_to_reuse: Set[str]): + def _checkNodeExecutionOptions( + self, + pipeline: pipeline_pb2.Pipeline, + snapshot_node: Optional[str], + nodes_to_run: Set[str], + nodes_requiring_snapshot: Set[str], + nodes_to_skip: Set[str], + nodes_required_to_reuse: Set[str], + nodes_optional_to_reuse: Optional[Set[str]] = None): for node in pipeline.nodes: try: node_id = node.pipeline_node.node_info.id @@ -112,13 +115,19 @@ def 
_checkNodeExecutionOptions(self, pipeline: pipeline_pb2.Pipeline, self.assertNotIn(node_id, nodes_requiring_snapshot) elif node_id in nodes_to_skip: self.assertEqual(run_or_skip, 'skip') - # assert reuse_artifacts iff node_id in nodes_to_reuse - if node_id in nodes_to_reuse: - self.assertTrue( - node.pipeline_node.execution_options.skip.reuse_artifacts) + # assert reuse_artifacts iff node_id in nodes_required_to_reuse + if node_id in nodes_required_to_reuse: + self.assertEqual( + node.pipeline_node.execution_options.skip.reuse_artifacts_mode, + pipeline_pb2.NodeExecutionOptions.Skip.REQUIRED) + elif nodes_optional_to_reuse and node_id in nodes_optional_to_reuse: + self.assertEqual( + node.pipeline_node.execution_options.skip.reuse_artifacts_mode, + pipeline_pb2.NodeExecutionOptions.Skip.OPTIONAL) else: - self.assertFalse( - node.pipeline_node.execution_options.skip.reuse_artifacts) + self.assertEqual( + node.pipeline_node.execution_options.skip.reuse_artifacts_mode, + pipeline_pb2.NodeExecutionOptions.Skip.NEVER) else: raise ValueError(f'node_id {node_id} appears in neither nodes_to_run ' 'nor nodes_to_skip.') @@ -287,7 +296,7 @@ def testFilterOutSinkNode(self): expected nodes_to_run: [node_a, node_b] expected nodes_requiring_snapshot: [] expected nodes_to_skip: [node_c] - expected nodes_to_reuse: [] + expected nodes_required_to_reuse: [] """ input_pipeline = self._createInputPipeline({ 'a': ['b'], @@ -303,7 +312,7 @@ def testFilterOutSinkNode(self): nodes_to_run=set(['a', 'b']), nodes_requiring_snapshot=set(), nodes_to_skip=set(['c']), - nodes_to_reuse=set()) + nodes_required_to_reuse=set()) def testFilterOutSourceNode(self): """Filter out a node that has no upstream nodes but has downstream nodes. 
@@ -316,7 +325,7 @@ def testFilterOutSourceNode(self): expected nodes_to_run: [node_b, node_c] expected nodes_requiring_snapshot: [node_b] expected nodes_to_skip: [node_a] - expected nodes_to_reuse: [node_a] + expected nodes_required_to_reuse: [node_a] """ input_pipeline = self._createInputPipeline({ 'a': ['b'], @@ -332,7 +341,7 @@ def testFilterOutSourceNode(self): nodes_to_run=set(['b', 'c']), nodes_requiring_snapshot=set(['b']), nodes_to_skip=set(['a']), - nodes_to_reuse=set(['a'])) + nodes_required_to_reuse=set(['a'])) def testFilterOutSourceNode_triangle(self): """Filter out a source node in a triangle. @@ -347,7 +356,7 @@ def testFilterOutSourceNode_triangle(self): expected nodes_to_run: [node_b, node_c] expected nodes_requiring_snapshot: [node_b, node_c] expected nodes_to_skip: [node_a] - expected nodes_to_reuse: [node_a] + expected nodes_required_to_reuse: [node_a] """ input_pipeline = self._createInputPipeline({ 'a': ['b', 'c'], @@ -363,7 +372,7 @@ def testFilterOutSourceNode_triangle(self): nodes_to_run=set(['b', 'c']), nodes_requiring_snapshot=set(['b', 'c']), nodes_to_skip=set(['a']), - nodes_to_reuse=set(['a'])) + nodes_required_to_reuse=set(['a'])) def testRunMiddleNode(self): """Run only the middle node. @@ -377,7 +386,7 @@ def testRunMiddleNode(self): expected nodes_to_run: [node_c] expected nodes_requiring_snapshot: [node_c] expected nodes_to_skip: [node_a, node_b, node_d] - expected nodes_to_reuse: [node_a, node_b] + expected nodes_required_to_reuse: [node_a, node_b] """ input_pipeline = self._createInputPipeline({ 'a': ['b'], @@ -394,7 +403,7 @@ def testRunMiddleNode(self): nodes_to_run=set(['c']), nodes_requiring_snapshot=set(['c']), nodes_to_skip=set(['a', 'b', 'd']), - nodes_to_reuse=set(['a', 'b'])) + nodes_required_to_reuse=set(['a', 'b'])) def testRunSinkNode(self): """Run only a sink node. 
@@ -412,7 +421,7 @@ def testRunSinkNode(self): expected nodes_to_run: [node_c] expected nodes_requiring_snapshot: [node_c] expected nodes_to_skip: [node_a, node_b] - expected nodes_to_reuse: [node_a, node_b] + expected nodes_required_to_reuse: [node_a, node_b] """ input_pipeline = self._createInputPipeline({ 'a': ['b'], @@ -427,7 +436,7 @@ def testRunSinkNode(self): nodes_to_run=set(['c']), nodes_requiring_snapshot=set(['c']), nodes_to_skip=set(['a', 'b']), - nodes_to_reuse=set(['a', 'b'])) + nodes_required_to_reuse=set(['a', 'b'])) def testRunSinkNode_triangle(self): """Filter out a source node in a triangle. @@ -441,7 +450,7 @@ def testRunSinkNode_triangle(self): expected nodes_to_run: [node_c] expected nodes_requiring_snapshot: [node_c] expected nodes_to_skip: [node_a, node_b] - expected nodes_to_reuse: [node_a, node_b] + expected nodes_required_to_reuse: [node_a, node_b] """ input_pipeline = self._createInputPipeline({ 'a': ['b', 'c'], @@ -456,7 +465,7 @@ def testRunSinkNode_triangle(self): nodes_to_run=set(['c']), nodes_requiring_snapshot=set(['c']), nodes_to_skip=set(['a', 'b']), - nodes_to_reuse=set(['a', 'b'])) + nodes_required_to_reuse=set(['a', 'b'])) def testRunMiddleNode_twoIndependentDAGs(self): """Run only a middle node in a pipeline with two independent DAGs. 
@@ -479,7 +488,7 @@ def testRunMiddleNode_twoIndependentDAGs(self): expected nodes_to_run: [node_b2] expected nodes_requiring_snapshot: [node_b2] expected nodes_to_skip: [node_a1, node_b1, node_c1, node_a2, node_c2] - expected nodes_to_reuse: [node_a1, node_b1, node_c1, node_a2] + expected nodes_required_to_reuse: [node_a1, node_b1, node_c1, node_a2] """ input_pipeline = self._createInputPipeline({ 'a1': ['b1'], @@ -498,7 +507,8 @@ def testRunMiddleNode_twoIndependentDAGs(self): nodes_to_run=set(['b2']), nodes_requiring_snapshot=set(['b2']), nodes_to_skip=set(['a1', 'b1', 'c1', 'a2', 'c2']), - nodes_to_reuse=set(['a1', 'b1', 'c1', 'a2'])) + nodes_required_to_reuse=set(['a2']), + nodes_optional_to_reuse=set(['a1', 'b1', 'c1'])) def testReuseableNodes(self): """Node excluded will not be run. @@ -512,7 +522,7 @@ def testReuseableNodes(self): expected nodes_to_run: [node_c] expected nodes_requiring_snapshot: [node_c] expected nodes_to_skip: [node_a, node_b, node_d] - expected nodes_to_reuse: [node_a, node_b] + expected nodes_required_to_reuse: [node_a, node_b] """ input_pipeline = self._createInputPipeline({ 'a': ['b'], @@ -528,7 +538,7 @@ def testReuseableNodes(self): nodes_to_run=set(['b', 'c', 'd']), nodes_requiring_snapshot=set(['b']), nodes_to_skip=set(['a']), - nodes_to_reuse=set(['a'])) + nodes_required_to_reuse=set(['a'])) input_pipeline = self._createInputPipeline({ 'a': ['b'], @@ -544,7 +554,7 @@ def testReuseableNodes(self): nodes_to_run=set(['c', 'd']), nodes_requiring_snapshot=set(['c']), nodes_to_skip=set(['a', 'b']), - nodes_to_reuse=set(['a', 'b'])) + nodes_required_to_reuse=set(['a', 'b'])) # pylint: disable=invalid-name @@ -1543,6 +1553,54 @@ def testReusePipelineArtifacts_inconsistentNewRunId_error(self): m, pipeline_pb_run_2, base_run_id='run_1', new_run_id='run_3') # <-- user error here + def testReusePipelineArtifacts_SeparateBranches(self): + """Tests partial run with separate branches.""" + 
############################################################################ + # + # This pipeline consists of *two* separate branches. + # + # --------- ----------- ----------- + # | Load 1 | -----> | AddNum 1 | ----> | Result 1 | # Branch 1 + # --------- ----------- ----------- + # run: \ + # \ ----------- ----------- + # \--> | AddNum 2 | ----> | Result 2 | # Branch 2 + # ----------- ----------- + # + ############################################################################ + # pylint: disable=no-value-for-parameter + load_1 = Load(start_num=1).with_id('load_1') + add_num_1 = AddNum(to_add=1, num=load_1.outputs['num']).with_id('add_num_1') + result_1 = Result(result=add_num_1.outputs['added_num']).with_id('result_1') + add_num_2 = AddNum( + to_add=10, num=load_1.outputs['num']).with_id('add_num_2') + result_2 = Result(result=add_num_2.outputs['added_num']).with_id('result_2') + # On first run, execute branch 1. + pipeline_pb_run_1 = self.make_pipeline( + components=[load_1, add_num_1, result_1, add_num_2, result_2], + run_id='run_1') + partial_run_utils.mark_pipeline( + pipeline_pb_run_1, from_nodes=[load_1.id], to_nodes=[result_1.id]) + beam_dag_runner.BeamDagRunner().run_with_ir(pipeline_pb_run_1) + self.assertResultEqual(pipeline_pb_run_1, [(result_1.id, 2)]) + + # On second run, only execute part of branch 1. Artifact reusing would + # fail for nodes on branch 2, but should not raise any error. + # pylint: disable=no-value-for-parameter + add_num_1_v2 = AddNum( + to_add=5, num=load_1.outputs['num']).with_id('add_num_1') + load_1.remove_downstream_node(add_num_1) # This line is important. 
+ result_1_v2 = Result( + result=add_num_1_v2.outputs['added_num']).with_id('result_1') + # pylint: enable=no-value-for-parameter + pipeline_pb_run_2 = self.make_pipeline( + components=[load_1, add_num_1_v2, result_1_v2, add_num_2, result_2], + run_id='run_2') + partial_run_utils.mark_pipeline( + pipeline_pb_run_2, from_nodes=[add_num_1_v2.id]) + beam_dag_runner.BeamDagRunner().run_with_ir(pipeline_pb_run_2) + self.assertResultEqual(pipeline_pb_run_2, [(result_1_v2.id, 6)]) + if __name__ == '__main__': absltest.main() diff --git a/tfx/proto/orchestration/pipeline.proto b/tfx/proto/orchestration/pipeline.proto index e03fc3984d..ace56f7289 100644 --- a/tfx/proto/orchestration/pipeline.proto +++ b/tfx/proto/orchestration/pipeline.proto @@ -326,10 +326,25 @@ message NodeExecutionOptions { bool depends_on_snapshot = 2; } message Skip { + // Deprecated. Please use reuse_artifacts_mode field instead. // If reuse_artifacts is true, the snapshot operation will make sure that // output artifacts produced by this node in a previous pipeline run will // be made available in this partial run. - bool reuse_artifacts = 1; + bool reuse_artifacts = 1 [deprecated = true]; + enum ReuseArtifactsMode { + UNSPECIFIED = 0; + // The snapshot operation will not reuse any output artifacts for this + // node. + NEVER = 1; + // The snapshot operation will make sure that output artifacts produced by + // this node in a previous pipeline run will be made available in this + // partial run. + REQUIRED = 2; + // The snapshot operation will attempt to reuse output artifacts on a + // best-effort basis. + OPTIONAL = 3; + } + ReuseArtifactsMode reuse_artifacts_mode = 2; } CachingOptions caching_options = 1; // Attached by platform-level tooling.