Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tag to ExternalPipelineChannel so we can get artifacts by tags. #6795

Merged
merged 1 commit into from
May 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions tfx/dsl/compiler/compiler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from tfx.dsl.compiler.testdata import conditional_pipeline
from tfx.dsl.compiler.testdata import consumer_pipeline
from tfx.dsl.compiler.testdata import consumer_pipeline_different_project
from tfx.dsl.compiler.testdata import consumer_pipeline_with_tags
from tfx.dsl.compiler.testdata import dynamic_exec_properties_pipeline
from tfx.dsl.compiler.testdata import external_artifacts_pipeline
from tfx.dsl.compiler.testdata import foreach_pipeline
Expand Down Expand Up @@ -143,6 +144,7 @@ def _get_pipeline_ir(self, filename: str) -> pipeline_pb2.Pipeline:
consumer_pipeline,
external_artifacts_pipeline,
consumer_pipeline_different_project,
consumer_pipeline_with_tags,
])
)
def testCompile(
Expand Down
80 changes: 65 additions & 15 deletions tfx/dsl/compiler/node_inputs_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@
# limitations under the License.
"""Compiler submodule specialized for NodeInputs."""

from collections.abc import Iterable
from typing import Type, cast
from collections.abc import Iterable, Sequence
import functools
from typing import Optional, Type, cast

from tfx import types
from tfx.dsl.compiler import compiler_context
Expand All @@ -41,6 +42,8 @@

from ml_metadata.proto import metadata_store_pb2

_PropertyPredicate = pipeline_pb2.PropertyPredicate


def _get_tfx_value(value: str) -> pipeline_pb2.Value:
"""Returns a TFX Value containing the provided string."""
Expand Down Expand Up @@ -135,14 +138,24 @@ def compile_op_node(op_node: resolver_op.OpNode):


def _compile_channel_pb_contexts(
    # TODO(b/264728226) Can flatten these args to make it more readable.
    types_values_and_predicates: Iterable[
        tuple[str, pipeline_pb2.Value, Optional[_PropertyPredicate]]
    ],
    result: pipeline_pb2.InputSpec.Channel,
):
  """Adds one context query to the channel per (type, value, predicate) tuple.

  Args:
    types_values_and_predicates: Tuples of context type name, context name
      value, and an optional property predicate. A falsy value or predicate
      leaves the corresponding field of the context query unset.
    result: The channel proto that is mutated in place; a new entry is appended
      to its `context_queries` for every tuple.
  """
  for context_type, context_value, predicate in types_values_and_predicates:
    ctx = result.context_queries.add()
    ctx.type.name = context_type
    # The name is only copied when a value was supplied, so a query can be
    # built from a type plus a predicate alone.
    if context_value:
      ctx.name.CopyFrom(context_value)
    if predicate:
      ctx.property_predicate.CopyFrom(predicate)


def _compile_channel_pb(
Expand All @@ -157,16 +170,19 @@ def _compile_channel_pb(
result.artifact_query.type.CopyFrom(mlmd_artifact_type)
result.artifact_query.type.ClearField('properties')

contexts_types_and_values = [
(constants.PIPELINE_CONTEXT_TYPE_NAME, _get_tfx_value(pipeline_name))
]
contexts_types_and_values = [(
constants.PIPELINE_CONTEXT_TYPE_NAME,
_get_tfx_value(pipeline_name),
None,
)]
if node_id:
contexts_types_and_values.append(
(
constants.NODE_CONTEXT_TYPE_NAME,
_get_tfx_value(
compiler_utils.node_context_name(pipeline_name, node_id)
),
None,
),
)
_compile_channel_pb_contexts(contexts_types_and_values, result)
Expand All @@ -175,6 +191,37 @@ def _compile_channel_pb(
result.output_key = output_key


def _construct_predicate(
    predicate_names_and_values: Sequence[tuple[str, metadata_store_pb2.Value]],
) -> Optional[_PropertyPredicate]:
  """Builds a single PropertyPredicate that ANDs together equality checks.

  Every (name, value) pair becomes an EQ comparison against the custom property
  called `name`. The comparisons are combined left to right with binary AND.

  Args:
    predicate_names_and_values: Pairs of custom property name and the value
      that property is compared against.

  Returns:
    None when no pairs are given, the lone comparison when there is exactly one
    pair, otherwise a left-nested chain of AND operators over all comparisons.
  """
  if not predicate_names_and_values:
    return None

  comparator_cls = _PropertyPredicate.ValueComparator
  logical_cls = _PropertyPredicate.BinaryLogicalOperator

  combined = None
  for property_name, expected in predicate_names_and_values:
    comparison = _PropertyPredicate(
        value_comparator=comparator_cls(
            property_name=property_name,
            op=comparator_cls.Op.EQ,
            target_value=pipeline_pb2.Value(field_value=expected),
            is_custom_property=True,
        )
    )
    if combined is None:
      # First comparison seeds the fold; it is returned as-is for one pair.
      combined = comparison
    else:
      combined = _PropertyPredicate(
          binary_logical_operator=logical_cls(
              op=logical_cls.AND, lhs=combined, rhs=comparison
          )
      )
  return combined


def _compile_input_spec(
*,
pipeline_ctx: compiler_context.PipelineContext,
Expand Down Expand Up @@ -206,7 +253,7 @@ def _compile_input_spec(
# from the same resolver function output.
if not hidden:
# Overwrite hidden = False even for already compiled channel, this is
# because we don't know the input should truly be hidden until the
# channel turns out not to be.
result.inputs[input_key].hidden = False
return
Expand Down Expand Up @@ -240,11 +287,15 @@ def _compile_input_spec(
result=result_input_channel,
)

if channel.pipeline_run_id:
if channel.pipeline_run_id or channel.run_context_predicates:
predicate = _construct_predicate(channel.run_context_predicates)
_compile_channel_pb_contexts(
[(
constants.PIPELINE_RUN_CONTEXT_TYPE_NAME,
_get_tfx_value(channel.pipeline_run_id),
_get_tfx_value(
channel.pipeline_run_id if channel.pipeline_run_id else ''
),
predicate,
)],
result_input_channel,
)
Expand Down Expand Up @@ -290,10 +341,9 @@ def _compile_input_spec(
contexts_to_add = []
for context_spec in node_contexts.contexts:
if context_spec.type.name == constants.PIPELINE_RUN_CONTEXT_TYPE_NAME:
contexts_to_add.append((
constants.PIPELINE_RUN_CONTEXT_TYPE_NAME,
context_spec.name,
))
contexts_to_add.append(
(constants.PIPELINE_RUN_CONTEXT_TYPE_NAME, context_spec.name, None)
)
_compile_channel_pb_contexts(contexts_to_add, result_input_channel)

elif isinstance(channel, channel_types.Channel):
Expand Down