diff --git a/tfx/dsl/input_resolution/ops/latest_policy_model_op.py b/tfx/dsl/input_resolution/ops/latest_policy_model_op.py index c9d8be4842..df642da35e 100644 --- a/tfx/dsl/input_resolution/ops/latest_policy_model_op.py +++ b/tfx/dsl/input_resolution/ops/latest_policy_model_op.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Module for LatestPolicyModel operator.""" + import collections import enum from typing import Dict @@ -23,6 +24,7 @@ from tfx.orchestration.portable.mlmd import event_lib from tfx.orchestration.portable.mlmd import filter_query_builder as q from tfx.types import artifact_utils +from tfx.types import external_artifact_utils from tfx.utils import typing_utils from ml_metadata.proto import metadata_store_pb2 @@ -344,7 +346,17 @@ def apply(self, input_dict: typing_utils.ArtifactMultiMap): input_child_artifacts = input_dict.get( ops_utils.MODEL_BLESSSING_KEY, [] ) + input_dict.get(ops_utils.MODEL_INFRA_BLESSING_KEY, []) - input_child_artifact_ids = set([a.id for a in input_child_artifacts]) + + input_child_artifact_ids = set() + for a in input_child_artifacts: + if a.is_external: + input_child_artifact_ids.add( + external_artifact_utils.get_id_from_external_id( + a.mlmd_artifact.external_id + ) + ) + else: + input_child_artifact_ids.add(a.id) # If the ModelBlessing and ModelInfraBlessing lists are empty, then no # child artifacts can be considered and we raise a SkipSignal. This can @@ -372,8 +384,38 @@ def apply(self, input_dict: typing_utils.ArtifactMultiMap): # There could be multiple events with the same execution ID but different # artifact IDs (e.g. model and baseline_model passed to an Evaluator), so we - # keep the values of model_artifact_ids_by_execution_id as sets. - model_artifact_ids = sorted(set(m.id for m in models)) + # keep the values of model_artifact_ids as sets. 
+ are_models_external = [m.is_external for m in models] + if any(are_models_external) and not all(are_models_external): + raise exceptions.InvalidArgument( + 'Inputs to the LastestPolicyModel are from both current pipeline and' + ' external pipeline. LastestPolicyModel does not support such usage.' + ) + if all(are_models_external): + pipeline_assets = set([ + external_artifact_utils.get_pipeline_asset_from_external_id( + m.mlmd_artifact.external_id + ) + for m in models + ]) + if len(pipeline_assets) != 1: + raise exceptions.InvalidArgument( + 'Input models to the LastestPolicyModel are from multiple' + ' pipelines. LastestPolicyModel does not support such usage.' + ) + + model_by_external_id = {m.mlmd_artifact.external_id: m for m in models} + deduped_models = list(model_by_external_id.values()) + model_artifact_ids = sorted( + set([ + external_artifact_utils.get_id_from_external_id(i) + for i in model_by_external_id.keys() + ]) + ) + else: + model_by_id = {m.id: m for m in models} + deduped_models = list(model_by_id.values()) + model_artifact_ids = sorted(set(model_by_id.keys())) downstream_artifact_type_names_filter_query = q.to_sql_string([ ops_utils.MODEL_BLESSING_TYPE_NAME, @@ -420,7 +462,7 @@ def event_filter(event): mlmd_resolver = metadata_resolver.MetadataResolver(self.context.store) # Populate the ModelRelations associated with each Model artifact and its # children. - model_relations_by_model_artifact_id = collections.defaultdict( + model_relations_by_model_identifier = collections.defaultdict( ModelRelations ) artifact_type_by_name: Dict[str, metadata_store_pb2.ArtifactType] = {} @@ -429,34 +471,34 @@ def event_filter(event): # fetching downstream artifacts, because # `get_downstream_artifacts_by_artifact_ids()` supports at most 100 ids # as starting artifact ids. 
- for id_index in range(0, len(model_artifact_ids), ops_utils.BATCH_SIZE): - batch_model_artifact_ids = model_artifact_ids[ + for id_index in range(0, len(deduped_models), ops_utils.BATCH_SIZE): + batch_model_artifacts = deduped_models[ id_index : id_index + ops_utils.BATCH_SIZE ] # Set `max_num_hops` to 50, which should be enough for this use case. - batch_downstream_artifacts_and_types_by_model_ids = ( - mlmd_resolver.get_downstream_artifacts_by_artifact_ids( - batch_model_artifact_ids, + batch_downstream_artifacts_and_types_by_model_identifier = ( + mlmd_resolver.get_downstream_artifacts_by_artifacts( + batch_model_artifacts, max_num_hops=ops_utils.LATEST_POLICY_MODEL_OP_MAX_NUM_HOPS, filter_query=filter_query, event_filter=event_filter, ) ) for ( - model_artifact_id, + model_identifier, artifacts_and_types, - ) in batch_downstream_artifacts_and_types_by_model_ids.items(): + ) in batch_downstream_artifacts_and_types_by_model_identifier.items(): for downstream_artifact, artifact_type in artifacts_and_types: artifact_type_by_name[artifact_type.name] = artifact_type - model_relations = model_relations_by_model_artifact_id[ - model_artifact_id - ] - model_relations.add_downstream_artifact(downstream_artifact) + model_relations_by_model_identifier[ + model_identifier + ].add_downstream_artifact(downstream_artifact) # Find the latest model and ModelRelations that meets the Policy. 
result = {} for model in models: - model_relations = model_relations_by_model_artifact_id[model.id] + identifier = external_artifact_utils.identifier(model) + model_relations = model_relations_by_model_identifier[identifier] if model_relations.meets_policy(self.policy): result[ops_utils.MODEL_KEY] = [model] break diff --git a/tfx/dsl/input_resolution/resolver_op.py b/tfx/dsl/input_resolution/resolver_op.py index 8594d93b6d..9d9af03fa7 100644 --- a/tfx/dsl/input_resolution/resolver_op.py +++ b/tfx/dsl/input_resolution/resolver_op.py @@ -12,13 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. """Module for ResolverOp and its related definitions.""" + from __future__ import annotations import abc -from typing import Any, Generic, Literal, Mapping, Optional, Sequence, Set, Type, TypeVar, Union +from typing import Any, Generic, Literal, Mapping, Optional, Sequence, Set, Type, TypeVar, Union, cast import attr from tfx import types +from tfx.orchestration import mlmd_connection_manager as mlmd_cm from tfx.proto.orchestration import pipeline_pb2 from tfx.utils import json_utils from tfx.utils import typing_utils @@ -28,13 +30,30 @@ # Mark frozen as context instance may be used across multiple operator # invocations. -@attr.s(auto_attribs=True, frozen=True, kw_only=True) class Context: """Context for running ResolverOp.""" - # MetadataStore for MLMD read access. - store: mlmd.MetadataStore - # TODO(jjong): Add more context such as current pipeline, current pipeline - # run, and current running node information. 
+ + def __init__( + self, + store=mlmd.MetadataStore, + mlmd_handle_like: Optional[mlmd_cm.HandleLike] = None, + ): + self._store = store + self._mlmd_handle_like = mlmd_handle_like + + @property + def store(self): + return self._store + + @property + def mlmd_connection_manager(self): + if isinstance(self._mlmd_handle_like, mlmd_cm.MLMDConnectionManager): + return cast(mlmd_cm.MLMDConnectionManager, self._mlmd_handle_like) + else: + return None + + # # TODO(jjong): Add more context such as current pipeline, current pipeline + # # run, and current running node information. # Note that to use DataType as a generic type parameter (e.g. diff --git a/tfx/orchestration/portable/input_resolution/input_graph_resolver.py b/tfx/orchestration/portable/input_resolution/input_graph_resolver.py index 5c6e04a9a9..667b224a7f 100644 --- a/tfx/orchestration/portable/input_resolution/input_graph_resolver.py +++ b/tfx/orchestration/portable/input_resolution/input_graph_resolver.py @@ -29,14 +29,14 @@ import collections import dataclasses import functools -from typing import Union, Sequence, Mapping, Tuple, List, Iterable, Callable +from typing import Callable, Iterable, List, Mapping, Sequence, Tuple, Union from tfx import types from tfx.dsl.components.common import resolver from tfx.dsl.input_resolution import resolver_op from tfx.dsl.input_resolution.ops import ops from tfx.orchestration import data_types_utils -from tfx.orchestration import metadata +from tfx.orchestration import mlmd_connection_manager as mlmd_cm from tfx.orchestration.portable.input_resolution import exceptions from tfx.proto.orchestration import pipeline_pb2 from tfx.utils import topsort @@ -52,8 +52,12 @@ @dataclasses.dataclass class _Context: - mlmd_handle: metadata.Metadata input_graph: pipeline_pb2.InputGraph + mlmd_handle_like: mlmd_cm.HandleLike + + @property + def mlmd_handle(self): + return mlmd_cm.get_handle(self.mlmd_handle_like) def _topologically_sorted_node_ids( @@ -131,7 +135,12 @@ def 
_evaluate_op_node( f'nodes[{node_id}] has unknown op_type {op_node.op_type}.') from e if issubclass(op_type, resolver_op.ResolverOp): op: resolver_op.ResolverOp = op_type.create(**kwargs) - op.set_context(resolver_op.Context(store=ctx.mlmd_handle.store)) + op.set_context( + resolver_op.Context( + store=mlmd_cm.get_handle(ctx.mlmd_handle_like).store, + mlmd_handle_like=ctx.mlmd_handle_like, + ) + ) return op.apply(*args) elif issubclass(op_type, resolver.ResolverStrategy): if len(args) != 1: @@ -207,7 +216,7 @@ def new_graph_fn(data: Mapping[str, _Data]): def build_graph_fn( - mlmd_handle: metadata.Metadata, + handle_like: mlmd_cm.HandleLike, input_graph: pipeline_pb2.InputGraph, ) -> Tuple[_GraphFn, List[str]]: """Build a functional interface for the `input_graph`. @@ -222,7 +231,7 @@ def build_graph_fn( z = graph_fn({'x': inputs['x'], 'y': inputs['y']}) Args: - mlmd_handle: A `Metadata` instance. + handle_like: A `mlmd_cm.HandleLike` instance. input_graph: An `pipeline_pb2.InputGraph` proto. Returns: @@ -235,7 +244,7 @@ def build_graph_fn( f'result_node {input_graph.result_node} does not exist in input_graph. 
' f'Valid node ids: {list(input_graph.nodes.keys())}') - context = _Context(mlmd_handle=mlmd_handle, input_graph=input_graph) + context = _Context(mlmd_handle_like=handle_like, input_graph=input_graph) input_key_to_node_id = {} for node_id in input_graph.nodes: diff --git a/tfx/orchestration/portable/input_resolution/node_inputs_resolver.py b/tfx/orchestration/portable/input_resolution/node_inputs_resolver.py index cad7d29c25..fee73bda28 100644 --- a/tfx/orchestration/portable/input_resolution/node_inputs_resolver.py +++ b/tfx/orchestration/portable/input_resolution/node_inputs_resolver.py @@ -341,7 +341,7 @@ def _join_artifacts( def _resolve_input_graph_ref( - mlmd_handle: metadata.Metadata, + handle_like: mlmd_cm.HandleLike, node_inputs: pipeline_pb2.NodeInputs, input_key: str, resolved: Dict[str, List[_Entry]], @@ -352,12 +352,12 @@ def _resolve_input_graph_ref( (i.e. `InputGraphRef` with the same `graph_id`). Args: - mlmd_handle: A `Metadata` instance. + handle_like: A `mlmd_cm.HandleLike` instance. node_inputs: A `NodeInputs` proto. input_key: A target input key whose corresponding `InputSpec` has an - `InputGraphRef`. + `InputGraphRef`. resolved: A dict that contains the already resolved inputs, and to which the - resolved result would be written from this function. + resolved result would be written from this function. 
""" graph_id = node_inputs.inputs[input_key].input_graph_ref.graph_id input_graph = node_inputs.input_graphs[graph_id] @@ -372,7 +372,8 @@ def _resolve_input_graph_ref( } graph_fn, graph_input_keys = input_graph_resolver.build_graph_fn( - mlmd_handle, node_inputs.input_graphs[graph_id]) + handle_like, node_inputs.input_graphs[graph_id] + ) for partition, input_dict in _join_artifacts(resolved, graph_input_keys): result = graph_fn(input_dict) if graph_output_type == _DataType.ARTIFACT_LIST: @@ -514,9 +515,7 @@ def resolve( (partition_utils.NO_PARTITION, _filter_live(artifacts)) ] elif input_spec.input_graph_ref.graph_id: - _resolve_input_graph_ref( - mlmd_cm.get_handle(handle_like), node_inputs, input_key, - resolved) + _resolve_input_graph_ref(handle_like, node_inputs, input_key, resolved) elif input_spec.mixed_inputs.input_keys: _resolve_mixed_inputs(node_inputs, input_key, resolved) elif input_spec.HasField('static_inputs'): diff --git a/tfx/types/external_artifact_utils.py b/tfx/types/external_artifact_utils.py new file mode 100644 index 0000000000..ba0b87c5db --- /dev/null +++ b/tfx/types/external_artifact_utils.py @@ -0,0 +1,35 @@ +# Copyright 2024 Google LLC. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
"""Third party version of external_artifact_utils.py.

Every helper in this module is a placeholder: it discards its argument and
implicitly returns None. NOTE(review): presumably a non-third-party build
supplies real implementations under the same names -- confirm before relying
on any return value.
"""


def get_artifact_id_from_external_id(external_id: str):
  """Placeholder for extracting an artifact id from an external id.

  Args:
    external_id: The external id string of an artifact. Unused here.

  Returns:
    None. The stub performs no parsing.
  """
  del external_id  # Unused by the placeholder implementation.


def get_id_from_external_id(external_id: str):
  """Same as `get_artifact_id_from_external_id`, under a second name.

  The LatestPolicyModel op calls this helper as
  `external_artifact_utils.get_id_from_external_id(...)`. Without this name
  the attribute lookup on the module raises AttributeError before any
  resolution logic runs.

  Args:
    external_id: The external id string of an artifact. Unused here.

  Returns:
    Whatever `get_artifact_id_from_external_id` returns (None in this stub).
  """
  return get_artifact_id_from_external_id(external_id)


def get_pipeline_asset_from_external_id(
    external_id: str,
):
  """Placeholder for extracting a pipeline asset from an external id.

  Args:
    external_id: The external id string of an artifact. Unused here.

  Returns:
    None. The stub performs no parsing.
  """
  del external_id  # Unused by the placeholder implementation.


def get_external_connection_config(
    external_id: str,
):
  """Placeholder for looking up a connection config from an external id.

  Args:
    external_id: The external id string of an artifact. Unused here.

  Returns:
    None. The stub performs no lookup.
  """
  del external_id  # Unused by the placeholder implementation.


def identifier(artifact):
  """Placeholder for computing an identifier of `artifact`.

  Args:
    artifact: The artifact to identify. Unused here.

  Returns:
    None. NOTE(review): callers that key a dict by this value will map every
    artifact to the same key while this stub is in use -- confirm that is
    acceptable.
  """
  del artifact  # Unused by the placeholder implementation.