Skip to content

Commit

Permalink
Optionally reuse artifacts that do not affect consistent execution of…
Browse files Browse the repository at this point in the history
… downstream nodes in partial run.

PiperOrigin-RevId: 424263054
  • Loading branch information
tfx-copybara committed Jan 26, 2022
1 parent 93c5ff4 commit 41666c4
Show file tree
Hide file tree
Showing 4 changed files with 179 additions and 57 deletions.
16 changes: 11 additions & 5 deletions tfx/orchestration/experimental/core/pipeline_ops_test.py
Expand Up @@ -128,8 +128,12 @@ def test_resume_pipeline(self, mock_snapshot):
with self._mlmd_connection as m:
pipeline = _test_pipeline('test_pipeline', pipeline_pb2.Pipeline.SYNC)
pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
pipeline.nodes.add().pipeline_node.node_info.id = 'ExampleGen'
pipeline.nodes.add().pipeline_node.node_info.id = 'Trainer'
node_example_gen = pipeline.nodes.add().pipeline_node
node_example_gen.node_info.id = 'ExampleGen'
node_example_gen.downstream_nodes.extend(['Trainer'])
node_trainer = pipeline.nodes.add().pipeline_node
node_trainer.node_info.id = 'Trainer'
node_trainer.upstream_nodes.extend(['ExampleGen'])

# Error if attempt to resume the pipeline when there is no previous run.
with self.assertRaises(status_lib.StatusNotOkError) as exception_context:
Expand Down Expand Up @@ -167,9 +171,11 @@ def test_resume_pipeline(self, mock_snapshot):
expected_pipeline.runtime_spec.snapshot_settings.latest_pipeline_run_strategy.SetInParent(
)
expected_pipeline.nodes[
0].pipeline_node.execution_options.skip.reuse_artifacts = True
0].pipeline_node.execution_options.skip.reuse_artifacts_mode = pipeline_pb2.NodeExecutionOptions.Skip.REQUIRED
expected_pipeline.nodes[
1].pipeline_node.execution_options.run.perform_snapshot = True
expected_pipeline.nodes[
1].pipeline_node.execution_options.run.depends_on_snapshot = True
with pipeline_ops.resume_pipeline(m, pipeline) as pipeline_state_run2:
self.assertEqual(expected_pipeline, pipeline_state_run2.pipeline)
pipeline_state_run2.is_active()
Expand Down Expand Up @@ -207,7 +213,7 @@ def test_initiate_pipeline_start_with_partial_run(self, mock_snapshot):
expected_pipeline.runtime_spec.snapshot_settings.latest_pipeline_run_strategy.SetInParent(
)
expected_pipeline.nodes[
0].pipeline_node.execution_options.skip.reuse_artifacts = True
0].pipeline_node.execution_options.skip.reuse_artifacts_mode = pipeline_pb2.NodeExecutionOptions.Skip.REQUIRED
expected_pipeline.nodes[
1].pipeline_node.execution_options.run.perform_snapshot = True
expected_pipeline.nodes[
Expand Down Expand Up @@ -247,7 +253,7 @@ def test_initiate_pipeline_start_with_partial_run_default_to_nodes(
expected_pipeline.runtime_spec.snapshot_settings.latest_pipeline_run_strategy.SetInParent(
)
expected_pipeline.nodes[
0].pipeline_node.execution_options.skip.reuse_artifacts = True
0].pipeline_node.execution_options.skip.reuse_artifacts_mode = pipeline_pb2.NodeExecutionOptions.Skip.REQUIRED
expected_pipeline.nodes[
1].pipeline_node.execution_options.run.perform_snapshot = True
expected_pipeline.nodes[
Expand Down
87 changes: 65 additions & 22 deletions tfx/orchestration/portable/partial_run_utils.py
Expand Up @@ -15,7 +15,7 @@

import collections
import enum
from typing import Collection, List, Mapping, Optional, Set
from typing import Collection, List, Mapping, Optional, Set, Tuple

from absl import logging
from tfx.dsl.compiler import compiler_utils
Expand All @@ -31,6 +31,8 @@

_default_snapshot_settings = pipeline_pb2.SnapshotSettings()
_default_snapshot_settings.latest_pipeline_run_strategy.SetInParent()
_REUSE_ARTIFACT_REQUIRED = pipeline_pb2.NodeExecutionOptions.Skip.REQUIRED
_REUSE_ARTIFACT_OPTIONAL = pipeline_pb2.NodeExecutionOptions.Skip.OPTIONAL


def mark_pipeline(
Expand Down Expand Up @@ -94,12 +96,13 @@ def mark_pipeline(
nodes_to_run = _compute_nodes_to_run(node_map, from_node_ids, to_node_ids,
skip_node_ids)

nodes_to_reuse = _compute_nodes_to_reuse(node_map, nodes_to_run)
nodes_required_to_reuse, nodes_to_reuse = _compute_nodes_to_reuse(
node_map, nodes_to_run)
nodes_requiring_snapshot = _compute_nodes_requiring_snapshot(
node_map, nodes_to_run, nodes_to_reuse)
snapshot_node = _pick_snapshot_node(node_map, nodes_to_run, nodes_to_reuse)
_mark_nodes(node_map, nodes_to_run, nodes_to_reuse, nodes_requiring_snapshot,
snapshot_node)
_mark_nodes(node_map, nodes_to_run, nodes_required_to_reuse, nodes_to_reuse,
nodes_requiring_snapshot, snapshot_node)
pipeline.runtime_spec.snapshot_settings.CopyFrom(snapshot_settings)
return pipeline

Expand All @@ -122,7 +125,7 @@ def snapshot(mlmd_handle: metadata.Metadata,
"""
# Avoid unnecessary snapshotting step if no node needs to reuse any artifacts.
if not any(
_should_reuse_artifact(node.pipeline_node.execution_options)
_should_attempt_to_reuse_artifact(node.pipeline_node.execution_options)
for node in pipeline.nodes):
return

Expand Down Expand Up @@ -155,17 +158,24 @@ def _pick_snapshot_node(node_map: Mapping[str, pipeline_pb2.PipelineNode],


def _mark_nodes(node_map: Mapping[str, pipeline_pb2.PipelineNode],
nodes_to_run: Set[str], nodes_to_reuse: Set[str],
nodes_requiring_snapshot: Set[str],
nodes_to_run: Set[str], nodes_required_to_reuse: Set[str],
nodes_to_reuse: Set[str], nodes_requiring_snapshot: Set[str],
snapshot_node: Optional[str]):
"""Mark nodes."""
for node_id, node in node_map.items(): # assumes topological order
if node_id in nodes_to_run:
node.execution_options.run.perform_snapshot = (node_id == snapshot_node)
node.execution_options.run.depends_on_snapshot = (
node_id in nodes_requiring_snapshot)
elif node_id in nodes_required_to_reuse:
node.execution_options.skip.reuse_artifacts_mode = (
_REUSE_ARTIFACT_REQUIRED)
elif node_id in nodes_to_reuse:
node.execution_options.skip.reuse_artifacts_mode = (
_REUSE_ARTIFACT_OPTIONAL)
else:
node.execution_options.skip.reuse_artifacts = (node_id in nodes_to_reuse)
node.execution_options.skip.reuse_artifacts_mode = (
pipeline_pb2.NodeExecutionOptions.Skip.NEVER)


class _Direction(enum.Enum):
Expand Down Expand Up @@ -330,19 +340,30 @@ def _compute_nodes_to_run(

def _compute_nodes_to_reuse(
node_map: Mapping[str, pipeline_pb2.PipelineNode],
nodes_to_run: Set[str],
) -> Set[str]:
"""Returns the set of node ids whose output artifacts are to be reused."""
nodes_to_run: Set[str]) -> Tuple[Set[str], Set[str]]:
"""Returns the set of node ids whose output artifacts are to be reused.
Only upstream nodes of nodes_to_run are required to be reused to reflect
correct lineage.
Args:
node_map: Mapping of node_id to nodes.
nodes_to_run: The set of nodes to run.
Returns:
Set of node ids required to be reused.
"""
exclusion_set = _traverse(
node_map, _Direction.DOWNSTREAM, start_nodes=nodes_to_run)
return set(node_map.keys()) - exclusion_set
nodes_required_to_reuse = _traverse(
node_map, _Direction.UPSTREAM, start_nodes=nodes_to_run) - exclusion_set
nodes_to_reuse = set(node_map.keys()) - exclusion_set
return nodes_required_to_reuse, nodes_to_reuse


def _compute_nodes_requiring_snapshot(
node_map: Mapping[str, pipeline_pb2.PipelineNode],
nodes_to_run: Set[str],
nodes_to_reuse: Set[str],
) -> Set[str]:
def _compute_nodes_requiring_snapshot(node_map: Mapping[
str, pipeline_pb2.PipelineNode], nodes_to_run: Set[str],
nodes_to_reuse: Set[str]) -> Set[str]:
"""Returns the set of nodes to run that depend on a node to reuse."""
result = set()
for node_id, node in node_map.items():
Expand Down Expand Up @@ -401,10 +422,19 @@ def _get_validated_new_run_id(pipeline: pipeline_pb2.Pipeline,
return str(inferred_new_run_id or new_run_id)


def _should_reuse_artifact(
def _should_attempt_to_reuse_artifact(
execution_options: pipeline_pb2.NodeExecutionOptions):
return execution_options.HasField('skip') and (
execution_options.skip.reuse_artifacts or
execution_options.skip.reuse_artifacts_mode == _REUSE_ARTIFACT_OPTIONAL or
execution_options.skip.reuse_artifacts_mode == _REUSE_ARTIFACT_REQUIRED)


def _reuse_artifact_required(
execution_options: pipeline_pb2.NodeExecutionOptions):
return (execution_options.HasField('skip') and
execution_options.skip.reuse_artifacts)
return execution_options.HasField('skip') and (
execution_options.skip.reuse_artifacts or
execution_options.skip.reuse_artifacts_mode == _REUSE_ARTIFACT_REQUIRED)


def _reuse_pipeline_run_artifacts(metadata_handler: metadata.Metadata,
Expand Down Expand Up @@ -449,9 +479,22 @@ def _reuse_pipeline_run_artifacts(metadata_handler: metadata.Metadata,
'base_run_id not provided. '
'Default to latest pipeline run: %s', base_run_id)
for node in marked_pipeline.nodes:
if _should_reuse_artifact(node.pipeline_node.execution_options):
if _should_attempt_to_reuse_artifact(node.pipeline_node.execution_options):
node_id = node.pipeline_node.node_info.id
artifact_recycler.reuse_node_outputs(node_id, base_run_id)
try:
artifact_recycler.reuse_node_outputs(node_id, base_run_id)
except Exception as err: # pylint: disable=broad-except
if _reuse_artifact_required(node.pipeline_node.execution_options):
# Raise error only if failed to reuse artifacts required.
raise
err_str = str(err)
if 'No previous successful execution' in err_str or 'node context' in err_str:
# This is mostly due to no previous execution of the node, so the
# error is safe to suppress.
logging.info(err_str)
else:
logging.warning('Failed to reuse artifacts for node %s. Due to %s',
node_id, err_str)
artifact_recycler.put_parent_context(base_run_id)


Expand Down

0 comments on commit 41666c4

Please sign in to comment.