Skip to content

Commit

Permalink
Add tag to ExternalPipelineChannel so we can get artifacts by tags.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 630465771
  • Loading branch information
tfx-copybara committed May 29, 2024
1 parent 4b69689 commit 1560977
Show file tree
Hide file tree
Showing 7 changed files with 605 additions and 17 deletions.
2 changes: 2 additions & 0 deletions tfx/dsl/compiler/compiler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from tfx.dsl.compiler.testdata import conditional_pipeline
from tfx.dsl.compiler.testdata import consumer_pipeline
from tfx.dsl.compiler.testdata import consumer_pipeline_different_project
from tfx.dsl.compiler.testdata import consumer_pipeline_with_tags
from tfx.dsl.compiler.testdata import dynamic_exec_properties_pipeline
from tfx.dsl.compiler.testdata import external_artifacts_pipeline
from tfx.dsl.compiler.testdata import foreach_pipeline
Expand Down Expand Up @@ -143,6 +144,7 @@ def _get_pipeline_ir(self, filename: str) -> pipeline_pb2.Pipeline:
consumer_pipeline,
external_artifacts_pipeline,
consumer_pipeline_different_project,
consumer_pipeline_with_tags,
])
)
def testCompile(
Expand Down
80 changes: 65 additions & 15 deletions tfx/dsl/compiler/node_inputs_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@
# limitations under the License.
"""Compiler submodule specialized for NodeInputs."""

from collections.abc import Iterable
from typing import Type, cast
from collections.abc import Iterable, Sequence
import functools
from typing import Optional, Type, cast

from tfx import types
from tfx.dsl.compiler import compiler_context
Expand All @@ -41,6 +42,8 @@

from ml_metadata.proto import metadata_store_pb2

_PropertyPredicate = pipeline_pb2.PropertyPredicate


def _get_tfx_value(value: str) -> pipeline_pb2.Value:
"""Returns a TFX Value containing the provided string."""
Expand Down Expand Up @@ -135,14 +138,24 @@ def compile_op_node(op_node: resolver_op.OpNode):


def _compile_channel_pb_contexts(
context_types_and_names: Iterable[tuple[str, pipeline_pb2.Value]],
# TODO(b/264728226) Can flatten these args to make it more readable.
types_values_and_predicates: Iterable[
tuple[str, pipeline_pb2.Value, Optional[_PropertyPredicate]]
],
result: pipeline_pb2.InputSpec.Channel,
):
"""Adds contexts to the channel."""
for context_type, context_value in context_types_and_names:
for (
context_type,
context_value,
predicate,
) in types_values_and_predicates:
ctx = result.context_queries.add()
ctx.type.name = context_type
ctx.name.CopyFrom(context_value)
if context_value:
ctx.name.CopyFrom(context_value)
if predicate:
ctx.property_predicate.CopyFrom(predicate)


def _compile_channel_pb(
Expand All @@ -157,16 +170,19 @@ def _compile_channel_pb(
result.artifact_query.type.CopyFrom(mlmd_artifact_type)
result.artifact_query.type.ClearField('properties')

contexts_types_and_values = [
(constants.PIPELINE_CONTEXT_TYPE_NAME, _get_tfx_value(pipeline_name))
]
contexts_types_and_values = [(
constants.PIPELINE_CONTEXT_TYPE_NAME,
_get_tfx_value(pipeline_name),
None,
)]
if node_id:
contexts_types_and_values.append(
(
constants.NODE_CONTEXT_TYPE_NAME,
_get_tfx_value(
compiler_utils.node_context_name(pipeline_name, node_id)
),
None,
),
)
_compile_channel_pb_contexts(contexts_types_and_values, result)
Expand All @@ -175,6 +191,37 @@ def _compile_channel_pb(
result.output_key = output_key


def _construct_predicate(
predicate_names_and_values: Sequence[tuple[str, metadata_store_pb2.Value]],
) -> Optional[_PropertyPredicate]:
"""Constructs a PropertyPredicate from a list of name and value pairs."""
if not predicate_names_and_values:
return None

predicates = []
for name, predicate_value in predicate_names_and_values:
predicates.append(
_PropertyPredicate(
value_comparator=_PropertyPredicate.ValueComparator(
property_name=name,
op=_PropertyPredicate.ValueComparator.Op.EQ,
target_value=pipeline_pb2.Value(field_value=predicate_value),
is_custom_property=True,
)
)
)

def _make_and(lhs, rhs):
return _PropertyPredicate(
binary_logical_operator=_PropertyPredicate.BinaryLogicalOperator(
op=_PropertyPredicate.BinaryLogicalOperator.AND, lhs=lhs, rhs=rhs
)
)

if predicates:
return functools.reduce(_make_and, predicates)


def _compile_input_spec(
*,
pipeline_ctx: compiler_context.PipelineContext,
Expand Down Expand Up @@ -206,7 +253,7 @@ def _compile_input_spec(
# from the same resolver function output.
if not hidden:
# Overwrite hidden = False even for already compiled channel, this is
# because we don't know the input should truely be hidden until the
# because we don't know the input should truly be hidden until the
# channel turns out not to be.
result.inputs[input_key].hidden = False
return
Expand Down Expand Up @@ -240,11 +287,15 @@ def _compile_input_spec(
result=result_input_channel,
)

if channel.pipeline_run_id:
if channel.pipeline_run_id or channel.run_context_predicates:
predicate = _construct_predicate(channel.run_context_predicates)
_compile_channel_pb_contexts(
[(
constants.PIPELINE_RUN_CONTEXT_TYPE_NAME,
_get_tfx_value(channel.pipeline_run_id),
_get_tfx_value(
channel.pipeline_run_id if channel.pipeline_run_id else ''
),
predicate,
)],
result_input_channel,
)
Expand Down Expand Up @@ -290,10 +341,9 @@ def _compile_input_spec(
contexts_to_add = []
for context_spec in node_contexts.contexts:
if context_spec.type.name == constants.PIPELINE_RUN_CONTEXT_TYPE_NAME:
contexts_to_add.append((
constants.PIPELINE_RUN_CONTEXT_TYPE_NAME,
context_spec.name,
))
contexts_to_add.append(
(constants.PIPELINE_RUN_CONTEXT_TYPE_NAME, context_spec.name, None)
)
_compile_channel_pb_contexts(contexts_to_add, result_input_channel)

elif isinstance(channel, channel_types.Channel):
Expand Down
Loading

0 comments on commit 1560977

Please sign in to comment.