From 07b193d9fa08e65a1333404df015848b8bf44e4d Mon Sep 17 00:00:00 2001 From: tfx-team Date: Wed, 29 May 2024 13:15:06 -0700 Subject: [PATCH] Add tag to ExternalPipelineChannel so we can get artifacts by tags. PiperOrigin-RevId: 638389192 --- tfx/dsl/compiler/compiler_test.py | 2 + tfx/dsl/compiler/node_inputs_compiler.py | 80 ++++-- tfx/dsl/compiler/node_inputs_compiler_test.py | 251 ++++++++++++++++++ .../testdata/consumer_pipeline_with_tags.py | 37 +++ ...sumer_pipeline_with_tags_input_v2_ir.pbtxt | 210 +++++++++++++++ tfx/types/channel.py | 15 +- tfx/types/channel_utils.py | 27 +- 7 files changed, 605 insertions(+), 17 deletions(-) create mode 100644 tfx/dsl/compiler/testdata/consumer_pipeline_with_tags.py create mode 100644 tfx/dsl/compiler/testdata/consumer_pipeline_with_tags_input_v2_ir.pbtxt diff --git a/tfx/dsl/compiler/compiler_test.py b/tfx/dsl/compiler/compiler_test.py index 013e895b0f..b9e5cdf6bb 100644 --- a/tfx/dsl/compiler/compiler_test.py +++ b/tfx/dsl/compiler/compiler_test.py @@ -33,6 +33,7 @@ from tfx.dsl.compiler.testdata import conditional_pipeline from tfx.dsl.compiler.testdata import consumer_pipeline from tfx.dsl.compiler.testdata import consumer_pipeline_different_project +from tfx.dsl.compiler.testdata import consumer_pipeline_with_tags from tfx.dsl.compiler.testdata import dynamic_exec_properties_pipeline from tfx.dsl.compiler.testdata import external_artifacts_pipeline from tfx.dsl.compiler.testdata import foreach_pipeline @@ -143,6 +144,7 @@ def _get_pipeline_ir(self, filename: str) -> pipeline_pb2.Pipeline: consumer_pipeline, external_artifacts_pipeline, consumer_pipeline_different_project, + consumer_pipeline_with_tags, ]) ) def testCompile( diff --git a/tfx/dsl/compiler/node_inputs_compiler.py b/tfx/dsl/compiler/node_inputs_compiler.py index bd6423ecae..33ee56f7dc 100644 --- a/tfx/dsl/compiler/node_inputs_compiler.py +++ b/tfx/dsl/compiler/node_inputs_compiler.py @@ -13,8 +13,9 @@ # limitations under the License. """Compiler submodule specialized for NodeInputs.""" -from collections.abc import Iterable -from typing import Type, cast +from collections.abc import Iterable, Sequence +import functools +from typing import Optional, Type, cast from tfx import types from tfx.dsl.compiler import compiler_context @@ -41,6 +42,8 @@ from ml_metadata.proto import metadata_store_pb2 +_PropertyPredicate = pipeline_pb2.PropertyPredicate + def _get_tfx_value(value: str) -> pipeline_pb2.Value: """Returns a TFX Value containing the provided string.""" @@ -135,14 +138,24 @@ def compile_op_node(op_node: resolver_op.OpNode): def _compile_channel_pb_contexts( - context_types_and_names: Iterable[tuple[str, pipeline_pb2.Value]], + # TODO(b/264728226) Can flatten these args to make it more readable. + types_values_and_predicates: Iterable[ + tuple[str, pipeline_pb2.Value, Optional[_PropertyPredicate]] + ], result: pipeline_pb2.InputSpec.Channel, ): """Adds contexts to the channel.""" - for context_type, context_value in context_types_and_names: + for ( + context_type, + context_value, + predicate, + ) in types_values_and_predicates: ctx = result.context_queries.add() ctx.type.name = context_type - ctx.name.CopyFrom(context_value) + if context_value: + ctx.name.CopyFrom(context_value) + if predicate: + ctx.property_predicate.CopyFrom(predicate) def _compile_channel_pb( @@ -157,9 +170,11 @@ def _compile_channel_pb( result.artifact_query.type.CopyFrom(mlmd_artifact_type) result.artifact_query.type.ClearField('properties') - contexts_types_and_values = [ - (constants.PIPELINE_CONTEXT_TYPE_NAME, _get_tfx_value(pipeline_name)) - ] + contexts_types_and_values = [( + constants.PIPELINE_CONTEXT_TYPE_NAME, + _get_tfx_value(pipeline_name), + None, + )] if node_id: contexts_types_and_values.append( ( @@ -167,6 +182,7 @@ def _compile_channel_pb( _get_tfx_value( compiler_utils.node_context_name(pipeline_name, node_id) ), + None, ), ) _compile_channel_pb_contexts(contexts_types_and_values, result) @@ -175,6 +191,37 @@ def _compile_channel_pb( result.output_key = output_key +def _construct_predicate( + predicate_names_and_values: Sequence[tuple[str, metadata_store_pb2.Value]], +) -> Optional[_PropertyPredicate]: + """Constructs a PropertyPredicate from a list of name and value pairs.""" + if not predicate_names_and_values: + return None + + predicates = [] + for name, predicate_value in predicate_names_and_values: + predicates.append( + _PropertyPredicate( + value_comparator=_PropertyPredicate.ValueComparator( + property_name=name, + op=_PropertyPredicate.ValueComparator.Op.EQ, + target_value=pipeline_pb2.Value(field_value=predicate_value), + is_custom_property=True, + ) + ) + ) + + def _make_and(lhs, rhs): + return _PropertyPredicate( + binary_logical_operator=_PropertyPredicate.BinaryLogicalOperator( + op=_PropertyPredicate.BinaryLogicalOperator.AND, lhs=lhs, rhs=rhs + ) + ) + + if predicates: + return functools.reduce(_make_and, predicates) + + def _compile_input_spec( *, pipeline_ctx: compiler_context.PipelineContext, @@ -206,7 +253,7 @@ def _compile_input_spec( # from the same resolver function output. if not hidden: # Overwrite hidden = False even for already compiled channel, this is - # because we don't know the input should truely be hidden until the + # because we don't know the input should truly be hidden until the # channel turns out not to be. result.inputs[input_key].hidden = False return @@ -240,11 +287,15 @@ def _compile_input_spec( result=result_input_channel, ) - if channel.pipeline_run_id: + if channel.pipeline_run_id or channel.run_context_predicates: + predicate = _construct_predicate(channel.run_context_predicates) _compile_channel_pb_contexts( [( constants.PIPELINE_RUN_CONTEXT_TYPE_NAME, - _get_tfx_value(channel.pipeline_run_id), + _get_tfx_value( + channel.pipeline_run_id if channel.pipeline_run_id else '' + ), + predicate, )], result_input_channel, ) @@ -290,10 +341,9 @@ def _compile_input_spec( contexts_to_add = [] for context_spec in node_contexts.contexts: if context_spec.type.name == constants.PIPELINE_RUN_CONTEXT_TYPE_NAME: - contexts_to_add.append(( - constants.PIPELINE_RUN_CONTEXT_TYPE_NAME, - context_spec.name, - )) + contexts_to_add.append( + (constants.PIPELINE_RUN_CONTEXT_TYPE_NAME, context_spec.name, None) + ) _compile_channel_pb_contexts(contexts_to_add, result_input_channel) elif isinstance(channel, channel_types.Channel): diff --git a/tfx/dsl/compiler/node_inputs_compiler_test.py b/tfx/dsl/compiler/node_inputs_compiler_test.py index 5bb2844e4f..f554bb3826 100644 --- a/tfx/dsl/compiler/node_inputs_compiler_test.py +++ b/tfx/dsl/compiler/node_inputs_compiler_test.py @@ -37,6 +37,7 @@ from tfx.types import standard_artifacts from google.protobuf import text_format +from ml_metadata.proto import metadata_store_pb2 class DummyArtifact(types.Artifact): @@ -292,6 +293,256 @@ def testCompileInputGraph(self): ctx, node, channel, result) self.assertEqual(input_graph_id, second_input_graph_id) + def testCompilePropertyPredicateForTags(self): + with self.subTest('zero tag'): + consumer = DummyNode( + 'MyConsumer', + inputs={ + 'input_key': channel_types.ExternalPipelineChannel( + artifact_type=DummyArtifact, + owner='MyProducer', + pipeline_name='pipeline_name', + producer_component_id='producer_component_id', + output_key='z', + run_context_predicates=[], + ) + }, + ) + result = self._compile_node_inputs(consumer, components=[consumer]) + self.assertLen(result.inputs['input_key'].channels, 1) + self.assertProtoEquals( + """ + context_queries { + type { + name: "pipeline" + } + name { + field_value { + string_value: "pipeline_name" + } + } + } + context_queries { + type { + name: "node" + } + name { + field_value { + string_value: "pipeline_name.producer_component_id" + } + } + } + artifact_query { + type { + name: "Dummy" + } + } + output_key: "z" + metadata_connection_config { + [type.googleapis.com/tfx.orchestration.MLMDServiceConfig] { + owner: "MyProducer" + name: "pipeline_name" + } + } + """, + result.inputs['input_key'].channels[0], + ) + + with self.subTest('one tag'): + consumer = DummyNode( + 'MyConsumer', + inputs={ + 'input_key': channel_types.ExternalPipelineChannel( + artifact_type=DummyArtifact, + owner='MyProducer', + pipeline_name='pipeline_name', + producer_component_id='producer_component_id', + output_key='z', + run_context_predicates=[ + ('tag_1', metadata_store_pb2.Value(bool_value=True)) + ], + ) + }, + ) + + result = self._compile_node_inputs(consumer, components=[consumer]) + + self.assertLen(result.inputs['input_key'].channels, 1) + self.assertProtoEquals( + """ + context_queries { + type { + name: "pipeline" + } + name { + field_value { + string_value: "pipeline_name" + } + } + } + context_queries { + type { + name: "node" + } + name { + field_value { + string_value: "pipeline_name.producer_component_id" + } + } + } + context_queries { + type { + name: "pipeline_run" + } + name { + field_value { + string_value: "" + } + } + property_predicate { + value_comparator { + property_name: "tag_1" + target_value { + field_value { + bool_value: true + } + } + op: EQ + is_custom_property: true + } + } + } + artifact_query { + type { + name: "Dummy" + } + } + output_key: "z" + metadata_connection_config { + [type.googleapis.com/tfx.orchestration.MLMDServiceConfig] { + owner: "MyProducer" + name: "pipeline_name" + } + } + """, + result.inputs['input_key'].channels[0], + ) + + with self.subTest('three tags'): + consumer = DummyNode( + 'MyConsumer', + inputs={ + 'input_key': channel_types.ExternalPipelineChannel( + artifact_type=DummyArtifact, + owner='MyProducer', + pipeline_name='pipeline_name', + producer_component_id='producer_component_id', + output_key='z', + run_context_predicates=[ + ('tag_1', metadata_store_pb2.Value(bool_value=True)), + ('tag_2', metadata_store_pb2.Value(bool_value=True)), + ('tag_3', metadata_store_pb2.Value(bool_value=True)), + ], + ) + }, + ) + + result = self._compile_node_inputs(consumer, components=[consumer]) + self.assertLen(result.inputs['input_key'].channels, 1) + self.assertProtoEquals( + """ + context_queries { + type { + name: "pipeline" + } + name { + field_value { + string_value: "pipeline_name" + } + } + } + context_queries { + type { + name: "node" + } + name { + field_value { + string_value: "pipeline_name.producer_component_id" + } + } + } + context_queries { + type { + name: "pipeline_run" + } + name { + field_value { + string_value: "" + } + } + property_predicate { + binary_logical_operator { + op: AND + lhs { + binary_logical_operator { + op: AND + lhs { + value_comparator { + property_name: "tag_1" + target_value { + field_value { + bool_value: true + } + } + op: EQ + is_custom_property: true + } + } + rhs { + value_comparator { + property_name: "tag_2" + target_value { + field_value { + bool_value: true + } + } + op: EQ + is_custom_property: true + } + } + } + } + rhs { + value_comparator { + property_name: "tag_3" + target_value { + field_value { + bool_value: true + } + } + op: EQ + is_custom_property: true + } + } + } + } + } + artifact_query { + type { + name: "Dummy" + } + } + output_key: "z" + metadata_connection_config { + [type.googleapis.com/tfx.orchestration.MLMDServiceConfig] { + owner: "MyProducer" + name: "pipeline_name" + } + } + """, + result.inputs['input_key'].channels[0], + ) + def testCompileInputGraphRef(self): with dummy_artifact_list.given_output_type(DummyArtifact): x1 = dummy_artifact_list() diff --git a/tfx/dsl/compiler/testdata/consumer_pipeline_with_tags.py b/tfx/dsl/compiler/testdata/consumer_pipeline_with_tags.py new file mode 100644 index 0000000000..de4b48ce51 --- /dev/null +++ b/tfx/dsl/compiler/testdata/consumer_pipeline_with_tags.py @@ -0,0 +1,37 @@ +# Copyright 2022 Google LLC. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Test pipeline for tfx.dsl.compiler.compiler.""" + +from tfx.components import StatisticsGen +from tfx.orchestration import pipeline +from tfx.types import channel_utils +from tfx.types import standard_artifacts + + +def create_test_pipeline(): + """Builds a consumer pipeline that gets artifacts from another project.""" + external_examples = channel_utils.external_pipeline_artifact_query( + artifact_type=standard_artifacts.Examples, + owner='owner', + pipeline_name='producer-pipeline', + producer_component_id='producer-component-id', + output_key='output-key', + pipeline_run_tags=['tag1', 'tag2', 'tag3'], + ) + + statistics_gen = StatisticsGen(examples=external_examples) + + return pipeline.Pipeline( + pipeline_name='consumer-pipeline', components=[statistics_gen] + ) diff --git a/tfx/dsl/compiler/testdata/consumer_pipeline_with_tags_input_v2_ir.pbtxt b/tfx/dsl/compiler/testdata/consumer_pipeline_with_tags_input_v2_ir.pbtxt new file mode 100644 index 0000000000..826f97bc60 --- /dev/null +++ b/tfx/dsl/compiler/testdata/consumer_pipeline_with_tags_input_v2_ir.pbtxt @@ -0,0 +1,210 @@ +pipeline_info { + id: "consumer-pipeline" +} +nodes { + pipeline_node { + node_info { + type { + name: "tfx.components.statistics_gen.component.StatisticsGen" + base_type: PROCESS + } + id: "StatisticsGen" + } + contexts { + contexts { + type { + name: "pipeline" + } + name { + field_value { + string_value: "consumer-pipeline" + } + } + } + contexts { + type { + name: "pipeline_run" + } + name { + runtime_parameter { + name: "pipeline-run-id" + type: STRING + } + } + } + contexts { + type { + name: "node" + } + name { + field_value { + string_value: "consumer-pipeline.StatisticsGen" + } + } + } + } + inputs { + inputs { + key: "examples" + value { + channels { + context_queries { + type { + name: "pipeline" + } + name { + field_value { + string_value: "producer-pipeline" + } + } + } + context_queries { + type { + name: "node" + } + name { + field_value { + string_value: "producer-pipeline.producer-component-id" + } + } + } + context_queries { + type { + name: "pipeline_run" + } + name { + field_value { + string_value: "" + } + } + property_predicate { + binary_logical_operator { + op: AND + lhs { + binary_logical_operator { + op: AND + lhs { + value_comparator { + property_name: "__tag_tag1__" + target_value { + field_value { + bool_value: true + } + } + op: EQ + is_custom_property: true + } + } + rhs { + value_comparator { + property_name: "__tag_tag2__" + target_value { + field_value { + bool_value: true + } + } + op: EQ + is_custom_property: true + } + } + } + } + rhs { + value_comparator { + property_name: "__tag_tag3__" + target_value { + field_value { + bool_value: true + } + } + op: EQ + is_custom_property: true + } + } + } + } + } + artifact_query { + type { + name: "Examples" + base_type: DATASET + } + } + output_key: "output-key" + metadata_connection_config { + [type.googleapis.com/tfx.orchestration.MLMDServiceConfig] { + owner: "owner" + name: "producer-pipeline" + } + } + } + min_count: 1 + } + } + } + outputs { + outputs { + key: "statistics" + value { + artifact_spec { + type { + name: "ExampleStatistics" + properties { + key: "span" + value: INT + } + properties { + key: "split_names" + value: STRING + } + base_type: STATISTICS + } + } + } + } + } + parameters { + parameters { + key: "exclude_splits" + value { + field_value { + string_value: "[]" + } + } + } + } + execution_options { + caching_options { + } + } + } +} +runtime_spec { + pipeline_root { + runtime_parameter { + name: "pipeline-root" + type: STRING + } + } + pipeline_run_id { + runtime_parameter { + name: "pipeline-run-id" + type: STRING + } + } +} +execution_mode: SYNC +deployment_config { + [type.googleapis.com/tfx.orchestration.IntermediateDeploymentConfig] { + executor_specs { + key: "StatisticsGen" + value { + [type.googleapis.com/tfx.orchestration.executable_spec.BeamExecutableSpec] { + python_executor_spec { + class_path: "tfx.components.statistics_gen.executor.Executor" + } + } + } + } + } +} diff --git a/tfx/types/channel.py b/tfx/types/channel.py index 9c79ea7e4b..de0c4c9a1d 100644 --- a/tfx/types/channel.py +++ b/tfx/types/channel.py @@ -722,6 +722,9 @@ def __init__( producer_component_id: str, output_key: str, pipeline_run_id: str = '', + run_context_predicates: Sequence[ + tuple[str, metadata_store_pb2.Value] + ] = (), ): """Initialization of ExternalPipelineChannel. @@ -733,13 +736,22 @@ def __init__( output_key: The output key when producer component produces the artifacts in this Channel. pipeline_run_id: (Optional) Pipeline run id the artifacts belong to. + run_context_predicates: (Optional) A list of run context property + predicates to filter run contexts. """ super().__init__(type=artifact_type) + + if pipeline_run_id and run_context_predicates: + raise ValueError( + 'pipeline_run_id and run_context_predicates cannot be both set.' + ) + self.owner = owner self.pipeline_name = pipeline_name self.producer_component_id = producer_component_id self.output_key = output_key self.pipeline_run_id = pipeline_run_id + self.run_context_predicates = run_context_predicates def get_data_dependent_node_ids(self) -> Set[str]: return set() @@ -751,7 +763,8 @@ def __repr__(self) -> str: f'pipeline_name={self.pipeline_name}, ' f'producer_component_id={self.producer_component_id}, ' f'output_key={self.output_key}, ' - f'pipeline_run_id={self.pipeline_run_id})' + f'pipeline_run_id={self.pipeline_run_id}), ' + f'run_context_predicates={self.run_context_predicates}' ) diff --git a/tfx/types/channel_utils.py b/tfx/types/channel_utils.py index 3712553833..b9240cc1bd 100644 --- a/tfx/types/channel_utils.py +++ b/tfx/types/channel_utils.py @@ -33,6 +33,8 @@ from tfx.types import artifact from tfx.types import channel +from ml_metadata.proto import metadata_store_pb2 + class ChannelForTesting(channel.BaseChannel): """Dummy channel for testing.""" @@ -149,6 +151,7 @@ def external_pipeline_artifact_query( producer_component_id: str, output_key: str, pipeline_run_id: str = '', + pipeline_run_tags: Sequence[str] = (), ) -> channel.ExternalPipelineChannel: """Helper function to construct a query to get artifacts from an external pipeline. @@ -160,16 +163,37 @@ def external_pipeline_artifact_query( output_key: The output key when producer component produces the artifacts in this Channel. pipeline_run_id: (Optional) Pipeline run id the artifacts belong to. + pipeline_run_tags: (Optional) A list of tags the artifacts belong to. It is + an AND relationship between tags. For example, if tags=['tag1', 'tag2'], + then only artifacts belonging to the run with both 'tag1' and 'tag2' will + be returned. Only one of pipeline_run_id and pipeline_run_tags can be set. Returns: channel.ExternalPipelineChannel instance. Raises: - ValueError, if owner or pipeline_name is missing. + ValueError, if owner or pipeline_name is missing, or both pipeline_run_id + and pipeline_run_tags are set. """ if not owner or not pipeline_name: raise ValueError('owner or pipeline_name is missing.') + if pipeline_run_id and pipeline_run_tags: + raise ValueError( + 'pipeline_run_id and pipeline_run_tags cannot be both set.' + ) + + run_context_predicates = [] + for tag in pipeline_run_tags: + # TODO(b/264728226): Find a better way to construct the tag name that used + # in MLMD. Tag names that used in MLMD are constructed in tflex_mlmd_api.py, + # but it is not visible in this file. + mlmd_store_tag = '__tag_' + tag + '__' + run_context_predicates.append(( + mlmd_store_tag, + metadata_store_pb2.Value(bool_value=True), + )) + return channel.ExternalPipelineChannel( artifact_type=artifact_type, owner=owner, @@ -177,6 +201,7 @@ def external_pipeline_artifact_query( producer_component_id=producer_component_id, output_key=output_key, pipeline_run_id=pipeline_run_id, + run_context_predicates=run_context_predicates, )