42 changes: 24 additions & 18 deletions tfx/dsl/components/base/base_driver.py
@@ -22,7 +22,7 @@
from tfx.dsl.io import fileio
from tfx.orchestration import data_types
from tfx.orchestration import metadata
from tfx.types import channel_utils
from tfx.types import artifact_utils
from tfx.types import channel_utils


def _generate_output_uri(base_output_dir: str,
@@ -41,8 +41,8 @@ def _generate_output_uri(base_output_dir: str,
def _prepare_output_paths(artifact: types.Artifact):
"""Create output directories for output artifact."""
if fileio.exists(artifact.uri):
msg = 'Output artifact uri %s already exists' % artifact.uri
absl.logging.warning(msg)
absl.logging.warning(
'Output artifact uri %s already exists', artifact.uri)
# TODO(b/158689199): We currently simply return as a short-term workaround
    # to unblock execution retries. A comprehensive solution to guarantee
# idempotent executions is needed.
@@ -85,12 +85,7 @@ def verify_input_artifacts(
Raises:
      RuntimeError: if any input has an empty or non-existing uri.
"""
for single_artifacts_list in artifacts_dict.values():
for artifact in single_artifacts_list:
if not artifact.uri:
raise RuntimeError('Artifact %s does not have uri' % artifact)
if not fileio.exists(artifact.uri):
raise RuntimeError('Artifact uri %s is missing' % artifact.uri)
artifact_utils.verify_artifacts(artifacts_dict)

def _log_properties(self, input_dict: Dict[str, List[types.Artifact]],
output_dict: Dict[str, List[types.Artifact]],
@@ -144,12 +139,12 @@ def resolve_input_artifacts(
# Note: when not initialized, artifact.uri is '' and artifact.id is
# 0.
if not artifact.uri or not artifact.id:
raise ValueError((
'Unresolved input channel %r for input %r was passed in '
'interactive mode. When running in interactive mode, upstream '
'components must first be run with '
'`interactive_context.run(component)` before their outputs can '
'be used in downstream components.') % (artifact, name))
          raise ValueError(
              f'Unresolved input channel {artifact!r} for input {name!r} was '
              'passed in interactive mode. When running in interactive mode, '
              'upstream components must first be run with '
              '`interactive_context.run(component)` before their outputs can '
              'be used in downstream components.')
artifacts_by_id.update({a.id: a for a in artifacts})
else:
artifacts = self._metadata_handler.search_artifacts(
@@ -292,7 +287,18 @@ def pre_execution(
exec_properties=exec_properties,
pipeline_info=pipeline_info,
component_info=component_info)
if output_artifacts is not None:

    # Check that the cached output artifacts still exist, so that reusing
    # them will actually be a valid cache hit for downstream components.
if output_artifacts is not None:
try:
artifact_utils.verify_artifacts(output_artifacts)
use_cached_results = True
except RuntimeError:
absl.logging.debug(
'Cached results found but could not be verified to still exist')

if use_cached_results:
      # If the cache should be used, update the execution to reflect that.
      # Note that with this update, the publisher will be skipped.
self._metadata_handler.update_execution(
@@ -301,9 +307,9 @@
output_artifacts=output_artifacts,
execution_state=metadata.EXECUTION_STATE_CACHED,
contexts=contexts)
use_cached_results = True
else:
absl.logging.debug('Cached results not found, move on to new execution')
absl.logging.debug(
'Cached results not available, move on to new execution')
# Step 4a. New execution is needed. Prepare output artifacts.
output_artifacts = self._prepare_output_artifacts(
input_artifacts=input_artifacts,
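Taken together, the driver change above means a metadata-level cache hit is only honored when the cached artifacts can still be found on disk. Below is a minimal sketch of the resulting decision flow, with the metadata lookup collapsed into a no-argument callable; `decide_cache_use` is a hypothetical helper for illustration, not driver API (the real call passes input artifacts, exec properties, and pipeline/component info):

from typing import Callable, Dict, List, Optional, Tuple

from tfx import types
from tfx.types import artifact_utils


def decide_cache_use(
    get_cached_outputs_fn: Callable[
        [], Optional[Dict[str, List[types.Artifact]]]],
) -> Tuple[bool, Optional[Dict[str, List[types.Artifact]]]]:
  """Returns (use_cached_results, cached output artifacts or None)."""
  use_cached_results = False
  output_artifacts = get_cached_outputs_fn()
  if output_artifacts is not None:
    try:
      # Raises RuntimeError if any artifact lacks a uri or its uri is gone.
      artifact_utils.verify_artifacts(output_artifacts)
      use_cached_results = True
    except RuntimeError:
      pass  # Stale cache entry; fall through to a fresh execution.
  return use_cached_results, output_artifacts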
55 changes: 48 additions & 7 deletions tfx/dsl/components/base/base_driver_test.py
@@ -26,18 +26,18 @@
from ml_metadata.proto import metadata_store_pb2

# Mock value for string artifact.
_STRING_VALUE = u'This is a string'
_STRING_VALUE = 'This is a string'

# Mock byte value for string artifact.
_BYTE_VALUE = b'This is a string'


def fake_read(self):
"""Mock read method for ValueArtifact."""
if not self._has_value:
self._has_value = True
self._value = self.decode(_BYTE_VALUE)
return self._value
  # pylint: disable=protected-access
  if not self._has_value:
    self._has_value = True
    self._value = self.decode(_BYTE_VALUE)
  return self._value


class _InputArtifact(types.Artifact):
@@ -104,7 +104,7 @@ def setUp(self):
@mock.patch(
'tfx.dsl.components.base.base_driver.BaseDriver.verify_input_artifacts')
@mock.patch.object(types.ValueArtifact, 'read', fake_read)
def testPreExecutionNewExecution(self, mock_verify_input_artifacts_fn):
def testPreExecutionNewExecution(self, _):
self._mock_metadata.search_artifacts.return_value = list(
self._input_dict['input_string'].get())
self._mock_metadata.register_execution.side_effect = [self._execution]
@@ -178,7 +178,8 @@ def testResolveInputArtifacts(self):
@mock.patch(
'tfx.dsl.components.base.base_driver.BaseDriver.verify_input_artifacts')
@mock.patch.object(types.ValueArtifact, 'read', fake_read)
def testPreExecutionCached(self, mock_verify_input_artifacts_fn):
def testPreExecutionCached(self, _):
"""With cache enabled, if cached output artifacts are found, execution decision is to use cache"""
self._mock_metadata.search_artifacts.return_value = list(
self._input_dict['input_string'].get())
self._mock_metadata.register_run_context_if_not_exists.side_effect = [
@@ -204,6 +205,46 @@ def testPreExecutionCached(self, mock_verify_input_artifacts_fn):
self.assertCountEqual(execution_decision.output_dict,
self._output_artifacts)

@mock.patch(
'tfx.dsl.components.base.base_driver.artifact_utils.verify_artifacts'
)
@mock.patch(
'tfx.dsl.components.base.base_driver.BaseDriver.verify_input_artifacts')
@mock.patch.object(types.ValueArtifact, 'read', fake_read)
def testPreExecutionCachedMissing(
self, _, mock_artifact_utils_verify_artifacts_fn
):
"""With cache enabled, if cached output artifacts are found but are missing, execution decision is to not use cache"""

    # Mock so that the output artifacts pulled from the cache appear missing.
mock_artifact_utils_verify_artifacts_fn.side_effect = RuntimeError()

self._mock_metadata.search_artifacts.return_value = list(
self._input_dict['input_string'].get())
self._mock_metadata.register_run_context_if_not_exists.side_effect = [
metadata_store_pb2.Context()
]
self._mock_metadata.register_execution.side_effect = [self._execution]
self._mock_metadata.get_cached_outputs.side_effect = [
self._output_artifacts
]

driver = base_driver.BaseDriver(metadata_handler=self._mock_metadata)
execution_decision = driver.pre_execution(
input_dict=self._input_dict,
output_dict=self._output_dict,
exec_properties=self._exec_properties,
driver_args=self._driver_args,
pipeline_info=self._pipeline_info,
component_info=self._component_info)

self.assertFalse(execution_decision.use_cached_results)
self.assertEqual(execution_decision.execution_id, self._execution_id)
self.assertCountEqual(execution_decision.exec_properties,
self._exec_properties)
self.assertCountEqual(execution_decision.output_dict,
self._output_artifacts)

def testVerifyInputArtifactsOk(self):
driver = base_driver.BaseDriver(metadata_handler=self._mock_metadata)
driver.verify_input_artifacts(self._input_artifacts)
8 changes: 6 additions & 2 deletions tfx/orchestration/portable/cache_utils.py
@@ -166,11 +166,15 @@ def get_cached_outputs(
cached_executions)

# Defensively traverses candidate executions and returns once we find an
# execution with valid outputs.
# execution with valid outputs that can be confirmed to still exist.
for execution in cached_executions:
cached_output_artifacts = _get_outputs_of_execution(metadata_handler,
execution.id)
if cached_output_artifacts is not None:
return cached_output_artifacts
try:
artifact_utils.verify_artifacts(cached_output_artifacts)
return cached_output_artifacts
    except RuntimeError:
      # These outputs can no longer be verified; try the next candidate.
      pass

return None
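For context, a hedged usage sketch of the strengthened lookup; the setup mirrors the tests below, and the sqlite connection config is only an in-memory stand-in for a real MLMD instance:

from ml_metadata.proto import metadata_store_pb2

from tfx.orchestration import metadata
from tfx.orchestration.portable import cache_utils
from tfx.orchestration.portable.mlmd import context_lib

connection_config = metadata_store_pb2.ConnectionConfig()
connection_config.sqlite.SetInParent()  # In-memory MLMD for illustration.

with metadata.Metadata(connection_config=connection_config) as m:
  cache_context = context_lib.register_context_if_not_exists(
      m, context_lib.CONTEXT_TYPE_EXECUTION_CACHE, 'cache_key')
  cached_outputs = cache_utils.get_cached_outputs(m, cache_context)
  if cached_outputs is None:
    # Either no succeeded execution matched the cache key, or no matching
    # execution's outputs could be verified to still exist on disk.
    pass  # Run the component instead of reusing the cache.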
18 changes: 15 additions & 3 deletions tfx/orchestration/portable/cache_utils_test.py
@@ -14,6 +14,7 @@
"""Tests for tfx.orchestration.portable.cache_utils."""
import os
from unittest import mock

import tensorflow as tf

from tfx.dsl.io import fileio
from tfx.orchestration import metadata
@@ -24,10 +25,10 @@
from tfx.proto.orchestration import pipeline_pb2
from tfx.types import standard_artifacts
from tfx.utils import test_case_utils

from google.protobuf import text_format
from ml_metadata.proto import metadata_store_pb2


class CacheUtilsTest(test_case_utils.TfxTest):

def setUp(self):
@@ -166,7 +167,9 @@ def testGetCacheContextTwiceDifferentExecutorSpec(self):
# Different executor spec will result in new cache context.
self.assertLen(m.store.get_contexts(), 2)

def testGetCachedOutputArtifacts(self):
  @mock.patch(
      'tfx.orchestration.portable.cache_utils.artifact_utils.verify_artifacts')
def testGetCachedOutputArtifacts(self, mock_verify_artifacts):
# Output artifacts that will be used by the first execution with the same
# cache key.
output_model_one = standard_artifacts.Model()
@@ -188,10 +191,12 @@ def testGetCachedOutputArtifacts(self):
with metadata.Metadata(connection_config=self._connection_config) as m:
cache_context = context_lib.register_context_if_not_exists(
m, context_lib.CONTEXT_TYPE_EXECUTION_CACHE, 'cache_key')
cached_output = cache_utils.get_cached_outputs(m, cache_context)

      # No succeeded execution is associated with this context yet, so the
      # cached output is None.
cached_output = cache_utils.get_cached_outputs(m, cache_context)
self.assertIsNone(cached_output)

execution_one = execution_publish_utils.register_execution(
m, metadata_store_pb2.ExecutionType(name='my_type'), [cache_context])
execution_publish_utils.publish_succeeded_execution(
@@ -210,6 +215,7 @@
output_models_key: [output_model_three, output_model_four],
output_examples_key: [output_example_two]
})

# The cached output got should be the artifacts produced by the most
# recent execution under the given cache context.
cached_output = cache_utils.get_cached_outputs(m, cache_context)
@@ -235,6 +241,12 @@
'create_time_since_epoch', 'last_update_time_since_epoch'
])

      # There should again be no cached outputs if the artifacts cannot be
      # verified as still existing.
mock_verify_artifacts.side_effect = RuntimeError()
cached_output = cache_utils.get_cached_outputs(m, cache_context)
self.assertIsNone(cached_output)

  def testGetCachedOutputArtifactsForNodesWithNoOutput(self):
with metadata.Metadata(connection_config=self._connection_config) as m:
cache_context = context_lib.register_context_if_not_exists(
71 changes: 50 additions & 21 deletions tfx/types/artifact_utils.py
@@ -17,10 +17,11 @@
import json
import os
import re
from typing import Dict, List, Optional, Type
from typing import Dict, List, Optional, Type, Union
from absl import logging
from packaging import version

from tfx.dsl.io import fileio
from tfx.types.artifact import _ArtifactType
from tfx.types.artifact import Artifact
from tfx.types.value_artifact import _ValueArtifactType
@@ -73,8 +74,8 @@ def get_single_instance(artifact_list: List[Artifact]) -> Artifact:
ValueError: If length of artifact_list is not one.
"""
if len(artifact_list) != 1:
raise ValueError('expected list length of one but got {}'.format(
len(artifact_list)))
raise ValueError(
f'expected list length of one but got {len(artifact_list)}')
return artifact_list[0]


@@ -119,14 +120,12 @@ def is_artifact_version_older_than(artifact: Artifact,
# Artifact without version.
return True

  if (version.parse(
      artifact.get_string_custom_property(
          ARTIFACT_TFX_VERSION_CUSTOM_PROPERTY_KEY)) <
      version.parse(artifact_version)):
    # Artifact with old version.
    return True
  else:
    return False
  # Artifact with old version.
  return bool(
      version.parse(
          artifact.get_string_custom_property(
              ARTIFACT_TFX_VERSION_CUSTOM_PROPERTY_KEY)) <
      version.parse(artifact_version))
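As a quick illustration of the ordering the rewritten helper relies on, a sketch using the `packaging` library this module already imports (the version strings are made up):

from packaging import version

# PEP 440-aware comparison handles multi-digit components correctly,
# whereas a lexicographic string comparison would claim '0.9.0' > '0.21.1'.
assert version.parse('0.9.0') < version.parse('0.21.1')
assert not version.parse('1.0.0') < version.parse('0.21.1')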


def get_split_uris(artifact_list: List[Artifact], split: str) -> List[str]:
Expand Down Expand Up @@ -154,8 +153,8 @@ def get_split_uris(artifact_list: List[Artifact], split: str) -> List[str]:
else:
result.append(os.path.join(artifact.uri, f'Split-{split}'))
if len(result) != len(artifact_list):
raise ValueError('Split does not exist over all example artifacts: %s' %
split)
raise ValueError(
f'Split does not exist over all example artifacts: {split}')
return result


Expand All @@ -175,8 +174,8 @@ def get_split_uri(artifact_list: List[Artifact], split: str) -> str:
artifact_split_uris = get_split_uris(artifact_list, split)
if len(artifact_split_uris) != 1:
raise ValueError(
('Expected exactly one artifact with split %r, but found matching '
'artifacts %s.') % (split, artifact_split_uris))
        f'Expected exactly one artifact with split {split!r}, but found '
        f'matching artifacts {artifact_split_uris}.')
return artifact_split_uris[0]


@@ -198,9 +197,9 @@ def encode_split_names(splits: List[str]) -> str:
# TODO(ccy): Disallow empty split names once the importer removes split as
# a property for all artifacts.
raise ValueError(
('Split names are expected to be alphanumeric (allowing dashes and '
'underscores, provided they are not the first character); got %r '
'instead.') % (split,))
          'Split names are expected to be alphanumeric (allowing dashes and '
          'underscores, provided they are not the first character); got '
          f'{split!r} instead.')
rewritten_splits.append(split)
return json.dumps(rewritten_splits)

@@ -303,16 +302,46 @@ def deserialize_artifact(
# Validate inputs.
if not isinstance(artifact_type, metadata_store_pb2.ArtifactType):
raise ValueError(
('Expected metadata_store_pb2.ArtifactType for artifact_type, got %s '
'instead') % (artifact_type,))
'Expected metadata_store_pb2.ArtifactType for artifact_type, got '
f'{artifact_type} instead')
if artifact and not isinstance(artifact, metadata_store_pb2.Artifact):
raise ValueError(
('Expected metadata_store_pb2.Artifact for artifact, got %s '
'instead') % (artifact,))
f'Expected metadata_store_pb2.Artifact for artifact, got {artifact} '
'instead')

# Get the artifact's class and construct the Artifact object.
artifact_cls = get_artifact_type_class(artifact_type)
result = artifact_cls()
result.artifact_type.CopyFrom(artifact_type)
result.set_mlmd_artifact(artifact or metadata_store_pb2.Artifact())
return result


def verify_artifacts(
    artifacts: Union[Dict[str, List[Artifact]], List[Artifact], Artifact]
) -> None:
  """Check that all artifacts have a uri and exist at that uri.

  Args:
    artifacts: an artifacts dict (key -> List[Artifact]), a single artifact
      list, or a single artifact instance.

  Raises:
    RuntimeError: if any artifact has an empty or non-existing uri.
  """
if isinstance(artifacts, Artifact):
artifact_list = [artifacts]
elif isinstance(artifacts, list):
artifact_list = artifacts
elif isinstance(artifacts, dict):
artifact_list = [
a for artifact_list in artifacts.values() for a in artifact_list
]
  else:
    raise TypeError(f'Unexpected type of artifacts: {type(artifacts)}')

for artifact_instance in artifact_list:
if not artifact_instance.uri:
raise RuntimeError(f'Artifact {artifact_instance} does not have uri')
if not fileio.exists(artifact_instance.uri):
raise RuntimeError(f'Artifact uri {artifact_instance.uri} is missing')