Only allow .future() on OutputChannel #6799

Merged: 1 commit, May 13, 2024
20 changes: 17 additions & 3 deletions tfx/dsl/compiler/compiler_test.py
@@ -205,10 +205,24 @@ def testCompileAdditionalCustomPropertyNameConflictError(self):
def testCompileDynamicExecPropTypeError(self):
dsl_compiler = compiler.Compiler()
test_pipeline = dynamic_exec_properties_pipeline.create_test_pipeline()
+    upstream_component = next(
+        c
+        for c in test_pipeline.components
+        if isinstance(
+            c,
+            type(
+                dynamic_exec_properties_pipeline.UpstreamComponent(start_num=0)
+            ),
+        )
+    )
     downstream_component = next(
-        c for c in test_pipeline.components
-        if isinstance(c, dynamic_exec_properties_pipeline.DownstreamComponent))
-    test_wrong_type_channel = channel.Channel(_MyType).future().value
+        c
+        for c in test_pipeline.components
+        if isinstance(c, dynamic_exec_properties_pipeline.DownstreamComponent)
+    )
+    test_wrong_type_channel = (
+        channel.OutputChannel(_MyType, upstream_component, "foo").future().value
+    )
downstream_component.exec_properties["input_num"] = test_wrong_type_channel
with self.assertRaisesRegex(
ValueError, ".*channel must be of a value artifact type.*"
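Note on the test change above: with this PR, `.future()` is only supported on `OutputChannel`, so the test now resolves the pipeline's `UpstreamComponent` and builds an `OutputChannel` from it instead of wrapping a bare `Channel`. A minimal sketch of the behavioral difference, assuming the base-class `future()` now raises (as the `tfx/types/channel.py` hunk further down shows) and using the new `TestNode` helper as a stand-in producer; `MyValue` is an illustrative artifact type, not part of the PR:

```python
from tfx.dsl.components.base.testing import test_node
from tfx.types import artifact, channel


class MyValue(artifact.Artifact):
  """Toy artifact type for illustration (stands in for the test's _MyType)."""
  TYPE_NAME = 'Illustration.MyValue'


# After this change, a plain Channel no longer supports .future().
try:
  channel.Channel(MyValue).future()
except NotImplementedError:
  pass

# An OutputChannel knows its producer node and output key, so it can be
# wrapped as a placeholder for a dynamic exec property.
out = channel.OutputChannel(MyValue, test_node.TestNode('producer'), 'foo')
dynamic_value = out.future().value
```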
42 changes: 28 additions & 14 deletions tfx/dsl/compiler/compiler_utils_test.py
@@ -15,25 +15,24 @@
 import itertools

 import tensorflow as tf
+from tfx import components
 from tfx import types
-from tfx.components import CsvExampleGen
-from tfx.components import StatisticsGen
 from tfx.dsl.compiler import compiler_utils
 from tfx.dsl.components.base import base_component
 from tfx.dsl.components.base import base_executor
 from tfx.dsl.components.base import executor_spec
+from tfx.dsl.components.base.testing import test_node
 from tfx.dsl.components.common import importer
 from tfx.dsl.components.common import resolver
 from tfx.dsl.input_resolution.strategies import latest_blessed_model_strategy
 from tfx.dsl.placeholder import placeholder as ph
 from tfx.orchestration import pipeline
 from tfx.proto.orchestration import pipeline_pb2
+from tfx.types import channel
 from tfx.types import standard_artifacts
 from tfx.types.artifact import Artifact
 from tfx.types.artifact import Property
 from tfx.types.artifact import PropertyType
-from tfx.types.channel import Channel
-from tfx.types.channel import OutputChannel
 from tfx.types.channel_utils import external_pipeline_artifact_query

 from google.protobuf import text_format
@@ -98,7 +97,7 @@ def testIsResolver(self):
strategy_class=latest_blessed_model_strategy.LatestBlessedModelStrategy)
self.assertTrue(compiler_utils.is_resolver(resv))

-    example_gen = CsvExampleGen(input_base="data_path")
+    example_gen = components.CsvExampleGen(input_base="data_path")
self.assertFalse(compiler_utils.is_resolver(example_gen))

def testHasResolverNode(self):
Expand All @@ -116,7 +115,7 @@ def testIsImporter(self):
source_uri="uri/to/schema", artifact_type=standard_artifacts.Schema)
self.assertTrue(compiler_utils.is_importer(impt))

-    example_gen = CsvExampleGen(input_base="data_path")
+    example_gen = components.CsvExampleGen(input_base="data_path")
self.assertFalse(compiler_utils.is_importer(example_gen))

def testEnsureTopologicalOrder(self):
@@ -128,9 +127,9 @@ def testEnsureTopologicalOrder(self):
valid_orders = {"abc", "acb"}
for order in itertools.permutations([a, b, c]):
if "".join([c.id for c in order]) in valid_orders:
-        self.assertTrue(compiler_utils.ensure_topological_order(order))
+        self.assertTrue(compiler_utils.ensure_topological_order(list(order)))
       else:
-        self.assertFalse(compiler_utils.ensure_topological_order(order))
+        self.assertFalse(compiler_utils.ensure_topological_order(list(order)))

def testIncompatibleExecutionMode(self):
p = pipeline.Pipeline(
@@ -143,8 +142,10 @@ def testIncompatibleExecutionMode(self):
compiler_utils.resolve_execution_mode(p)

def testHasTaskDependency(self):
-    example_gen = CsvExampleGen(input_base="data_path")
-    statistics_gen = StatisticsGen(examples=example_gen.outputs["examples"])
+    example_gen = components.CsvExampleGen(input_base="data_path")
+    statistics_gen = components.StatisticsGen(
+        examples=example_gen.outputs["examples"]
+    )
p1 = pipeline.Pipeline(
pipeline_name="fake_name",
pipeline_root="fake_root",
@@ -204,7 +205,14 @@ class ValidateExecPropertyPlaceholderTest(tf.test.TestCase):
def test_accepts_canonical_dynamic_exec_prop_placeholder(self):
# .future()[0].uri is how we tell users to hook up a dynamic exec prop.
compiler_utils.validate_exec_property_placeholder(
"testkey", Channel(type=_MyType).future()[0].value
"testkey",
channel.OutputChannel(
artifact_type=_MyType,
producer_component=test_node.TestNode("producer"),
output_key="foo",
)
.future()[0]
.value,
)

def test_accepts_complex_exec_prop_placeholder(self):
Expand All @@ -219,7 +227,13 @@ def test_accepts_complex_exec_prop_placeholder(self):
def test_accepts_complex_dynamic_exec_prop_placeholder(self):
compiler_utils.validate_exec_property_placeholder(
"testkey",
-        Channel(type=_MyType).future()[0].value
+        channel.OutputChannel(
+            artifact_type=_MyType,
+            producer_component=test_node.TestNode("producer"),
+            output_key="foo",
+        )
+        .future()[0]
+        .value
+ "foo"
+ ph.input("someartifact").uri
+ "/somefile.txt",
@@ -265,14 +279,14 @@ def test_rejects_exec_property_dependency(self):
)

def testOutputSpecFromChannel_AsyncOutputChannel(self):
-    channel = OutputChannel(
+    ch = channel.OutputChannel(
artifact_type=standard_artifacts.Model,
output_key="model",
producer_component="trainer",
is_async=True,
)

-    actual = compiler_utils.output_spec_from_channel(channel, "trainer")
+    actual = compiler_utils.output_spec_from_channel(ch, "trainer")
expected = text_format.Parse(
"""
artifact_spec {
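One detail worth calling out in the hunk above: the local variable in `testOutputSpecFromChannel_AsyncOutputChannel` is renamed from `channel` to `ch` because the test module now imports `from tfx.types import channel`, and rebinding that name inside the method would shadow the module. A quick illustration of the pitfall (hypothetical helper, not part of the PR):

```python
from tfx.types import channel, standard_artifacts


def build_async_model_channel():
  # channel = channel.OutputChannel(...)  # would raise UnboundLocalError:
  # assigning to `channel` makes it a local name for the whole function body.
  ch = channel.OutputChannel(  # renaming the local keeps the module usable
      artifact_type=standard_artifacts.Model,
      output_key='model',
      producer_component='trainer',
      is_async=True,
  )
  return ch
```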
31 changes: 31 additions & 0 deletions tfx/dsl/components/base/testing/test_node.py
@@ -0,0 +1,31 @@
# Copyright 2024 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module to provide a node for tests."""

from tfx.dsl.components.base import base_node


class TestNode(base_node.BaseNode):
  """Node purely for testing, intentionally empty.

  DO NOT USE in real pipelines.
  """

  inputs = {}
  outputs = {}
  exec_properties = {}

  def __init__(self, name: str):
    super().__init__()
    self.with_id(name)
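This helper exists so tests can hand `OutputChannel` a syntactically valid producer node without constructing a real component. A typical use, mirroring the test updates elsewhere in this PR (`standard_artifacts.Examples` and the key `'examples'` are just illustrative values):

```python
from tfx.dsl.components.base.testing import test_node
from tfx.types import channel, standard_artifacts

# TestNode satisfies OutputChannel's producer_component argument in tests.
ch = channel.OutputChannel(
    artifact_type=standard_artifacts.Examples,
    producer_component=test_node.TestNode('producer'),
    output_key='examples',
)
# .future()[0].uri is the documented way to hook up a dynamic exec property.
dynamic_prop = ch.future()[0].uri
```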
9 changes: 6 additions & 3 deletions tfx/orchestration/kubeflow/v2/compiler_utils.py
@@ -113,7 +113,8 @@ def build_parameter_type_spec(

def _validate_properties_schema(
instance_schema: str,
-    properties: Optional[Mapping[str, artifact.PropertyType]] = None):
+    properties: Optional[Mapping[str, artifact.Property]] = None,
+):
"""Validates the declared property types are consistent with the schema.

Args:
@@ -145,8 +146,10 @@ def _validate_properties_schema(
v.type != artifact.PropertyType.STRING or
schema[k]['type'] == _YAML_DOUBLE_TYPE and
v.type != artifact.PropertyType.FLOAT):
-      raise TypeError(f'Property type mismatched at {k} for schema: {schema}. '
-                      f'Expected {schema[k]["type"]} but got {v.type}')
+      raise TypeError(
+          f'Property type mismatched at {k} for schema: {schema}. Expected'
+          f' {schema[k]["type"]} but got {v.type}'
+      )
# pytype: enable=attribute-error # use-enum-overlay


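For context, `_validate_properties_schema` compares each declared property type against the type recorded in the artifact's YAML `instance_schema` and raises `TypeError` on mismatch; this hunk only changes the annotation (`artifact.Property` instead of `artifact.PropertyType`) and reflows the error message. A simplified, self-contained sketch of that check, not the exact TFX implementation:

```python
import enum

import yaml  # assumed available; TFX parses the schema's YAML similarly


class PropertyType(enum.Enum):
  INT = 1
  FLOAT = 2
  STRING = 3


# Assumed mapping from YAML type strings to declared property types.
_YAML_TO_PROPERTY_TYPE = {
    'int': PropertyType.INT,
    'double': PropertyType.FLOAT,
    'string': PropertyType.STRING,
}


def validate_properties_schema(instance_schema: str, properties) -> None:
  """Raises TypeError when a declared property type disagrees with the schema."""
  schema = yaml.safe_load(instance_schema).get('properties', {})
  for key, prop_type in (properties or {}).items():
    declared = schema.get(key, {}).get('type')
    expected = _YAML_TO_PROPERTY_TYPE.get(declared)
    if expected is not None and expected != prop_type:
      raise TypeError(
          f'Property type mismatched at {key} for schema: {schema}. '
          f'Expected {declared} but got {prop_type}'
      )
```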
10 changes: 8 additions & 2 deletions tfx/orchestration/kubeflow/v2/compiler_utils_test.py
@@ -18,6 +18,7 @@
from absl.testing import parameterized
from kfp.pipeline_spec import pipeline_spec_pb2 as pipeline_pb2
import tensorflow as tf
+from tfx.dsl.components.base.testing import test_node
from tfx.dsl.io import fileio
from tfx.orchestration import data_types
from tfx.orchestration.kubeflow.v2 import compiler_utils
@@ -70,7 +71,11 @@ class _MyArtifactWithProperty(artifact.Artifact):
}


-_TEST_CHANNEL = channel.Channel(type=_MyArtifactWithProperty)
+_TEST_CHANNEL = channel.OutputChannel(
+    artifact_type=_MyArtifactWithProperty,
+    producer_component=test_node.TestNode('producer'),
+    output_key='foo',
+)


class CompilerUtilsTest(tf.test.TestCase):
@@ -133,7 +138,8 @@ def testCustomArtifactSchemaMismatchFails(self):
with self.assertRaisesRegex(TypeError, 'Property type mismatched at'):
compiler_utils._validate_properties_schema(
_MY_BAD_ARTIFACT_SCHEMA_WITH_PROPERTIES,
-          _MyArtifactWithProperty.PROPERTIES)
+          _MyArtifactWithProperty.PROPERTIES,
+      )

def testBuildParameterTypeSpecLegacy(self):
type_enum = pipeline_pb2.PrimitiveType.PrimitiveTypeEnum
@@ -17,13 +17,15 @@
from unittest import mock

import tensorflow as tf
+from tfx.dsl.components.base.testing import test_node
from tfx.orchestration.portable.input_resolution import exceptions
from tfx.orchestration.portable.input_resolution import input_graph_resolver
from tfx.orchestration.portable.input_resolution import node_inputs_resolver
from tfx.orchestration.portable.input_resolution import partition_utils
from tfx.orchestration.portable.input_resolution import channel_resolver
from tfx.proto.orchestration import pipeline_pb2
import tfx.types
+from tfx.types import channel
from tfx.types import channel_utils
from tfx.utils import test_case_utils

Expand Down Expand Up @@ -76,12 +78,18 @@ def no(nodes, dependencies):
except exceptions.FailedPreconditionError:
self.fail('Expected no cycle but has cycle.')

-    no('', {})
-    yes('a', {'a': 'a'})
-    yes('ab', {'a': 'b', 'b': 'a'})
-    yes('abc', {'a': 'b', 'b': 'c', 'c': 'a'})
-    no('abcd', {'a': 'bcd', 'b': '', 'c': '', 'd': ''})
-    no('abcd', {'a': 'bc', 'b': 'd', 'c': 'd', 'd': ''})
+    no(list(), {})
+    yes(list('a'), {'a': list('a')})
+    yes(list('ab'), {'a': list('b'), 'b': list('a')})
+    yes(list('abc'), {'a': list('b'), 'b': list('c'), 'c': list('a')})
+    no(
+        list('abcd'),
+        {'a': list('bcd'), 'b': list(''), 'c': list(''), 'd': list('')},
+    )
+    no(
+        list('abcd'),
+        {'a': list('bc'), 'b': list('d'), 'c': list('d'), 'd': list('')},
+    )

def testTopologicallySortedInputKeys(self):
node_inputs = self.parse_node_inputs("""
@@ -264,8 +272,8 @@ def setUp(self):

def mock_channel_resolution_result(self, input_spec, artifacts):
assert len(input_spec.channels) == 1
-    for channel in input_spec.channels:
-      channel_key = text_format.MessageToString(channel, as_one_line=True)
+    for chnl in input_spec.channels:
+      channel_key = text_format.MessageToString(chnl, as_one_line=True)
self._channel_resolve_result[channel_key] = artifacts

def mock_graph_fn_result(self, input_graph, graph_fn, dependent_inputs=()):
@@ -275,8 +283,8 @@ def mock_graph_fn_result(self, input_graph, graph_fn, dependent_inputs=()):
def _mock_resolve_union_channels(self, store, channels):
del store # Unused.
result = []
-    for channel in channels:
-      channel_key = text_format.MessageToString(channel, as_one_line=True)
+    for chnl in channels:
+      channel_key = text_format.MessageToString(chnl, as_one_line=True)
result.extend(self._channel_resolve_result[channel_key])
return result

@@ -676,15 +684,28 @@ def testConditionals(self):
# Only allows artifact.custom_properties['blessed'] == 1,
# which is a1 and a4.
is_blessed = channel_utils.encode_placeholder_with_channels(
-        DummyChannel('x').future()[0].custom_property('blessed') == 1,
-        lambda channel: channel.name,
+        channel.OutputChannel(
+            artifact_type=DummyArtifact,
+            producer_component=test_node.TestNode('foo'),
+            output_key='x',
+        )
+        .future()[0]
+        .custom_property('blessed')
+        == 1,
+        lambda _: 'x',

# Only allows artifact.custom_properties['tag'] == 'foo'
# which is a1 and a2.
is_foo = channel_utils.encode_placeholder_with_channels(
-        (DummyChannel('x').future()[0].custom_property('tag') == 'foo'),
-        lambda channel: channel.name,
+        channel.OutputChannel(
+            artifact_type=DummyArtifact,
+            producer_component=test_node.TestNode('foo'),
+            output_key='x',
+        )
+        .future()[0]
+        .custom_property('tag')
+        == 'foo',
+        lambda _: 'x',
)

cond_1 = pipeline_pb2.NodeInputs.Conditional(
Expand Down Expand Up @@ -740,8 +761,15 @@ def testConditionals_FalseCondAlwaysReturnsEmpty(self):

# Only allows artifact.custom_properties['blessed'] == 1,
is_blessed = channel_utils.encode_placeholder_with_channels(
-        DummyChannel('b').future()[0].custom_property('blessed') == 1,
-        lambda channel: channel.name,
+        channel.OutputChannel(
+            artifact_type=DummyArtifact,
+            producer_component=test_node.TestNode('foo'),
+            output_key='x',
+        )
+        .future()[0]
+        .custom_property('blessed')
+        == 1,
+        lambda _: 'b',
)
cond = pipeline_pb2.NodeInputs.Conditional(
placeholder_expression=is_blessed
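The conditional tests above also change the key function passed to `channel_utils.encode_placeholder_with_channels`: the old `DummyChannel` helper exposed a `.name`, whereas with a real `OutputChannel` the tests simply return the intended input key as a constant. The shape of such a call, using `standard_artifacts.Examples` as a stand-in for the test's `DummyArtifact`:

```python
from tfx.dsl.components.base.testing import test_node
from tfx.types import channel, channel_utils, standard_artifacts

ch = channel.OutputChannel(
    artifact_type=standard_artifacts.Examples,  # stand-in artifact type
    producer_component=test_node.TestNode('foo'),
    output_key='x',
)

# Encode a predicate over the channel's future artifact; the lambda maps every
# channel referenced by the placeholder to the node-inputs key 'x'.
is_blessed = channel_utils.encode_placeholder_with_channels(
    ch.future()[0].custom_property('blessed') == 1,
    lambda _: 'x',
)
```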
8 changes: 7 additions & 1 deletion tfx/types/channel.py
@@ -204,7 +204,7 @@ def trigger_by_property(self, *property_keys: str):
return self._with_input_trigger(TriggerByProperty(property_keys))

def future(self) -> ChannelWrappedPlaceholder:
-    return ChannelWrappedPlaceholder(self)
+    raise NotImplementedError()

def __eq__(self, other):
return self is other
@@ -557,6 +557,9 @@ def set_external(self, predefined_artifact_uris: List[str]) -> None:
def set_as_async_channel(self) -> None:
self._is_async = True

+  def future(self) -> ChannelWrappedPlaceholder:
+    return ChannelWrappedPlaceholder(self)


@doc_controls.do_not_generate_docs
class UnionChannel(BaseChannel):
@@ -703,6 +706,9 @@ def trigger_by_property(self, *property_keys: str):
'trigger_by_property is not implemented for PipelineInputChannel.'
)

+  def future(self) -> ChannelWrappedPlaceholder:
+    return ChannelWrappedPlaceholder(self)


class ExternalPipelineChannel(BaseChannel):
"""Channel subtype that is used to get artifacts from external MLMD db."""
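Taken together, the `channel.py` hunks are the heart of this PR: the base implementation of `future()` now raises, and only channel types that represent a concrete, addressable output (`OutputChannel`, `PipelineInputChannel`) return a `ChannelWrappedPlaceholder`. A stripped-down sketch of the pattern, with simplified stand-in classes rather than the real TFX ones:

```python
class ChannelWrappedPlaceholder:
  """Lets a channel's not-yet-produced artifact be referenced as a placeholder."""

  def __init__(self, channel):
    self.channel = channel


class BaseChannel:

  def future(self) -> "ChannelWrappedPlaceholder":
    # A generic channel has no concrete producer/output to point at, so it
    # cannot be turned into a placeholder.
    raise NotImplementedError()


class OutputChannel(BaseChannel):

  def __init__(self, artifact_type, producer_component, output_key):
    self.type = artifact_type
    self.producer_component = producer_component
    self.output_key = output_key

  def future(self) -> "ChannelWrappedPlaceholder":
    # Concrete output of a known node: safe to wrap as a placeholder.
    return ChannelWrappedPlaceholder(self)


class PipelineInputChannel(BaseChannel):

  def future(self) -> "ChannelWrappedPlaceholder":
    return ChannelWrappedPlaceholder(self)
```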
7 changes: 6 additions & 1 deletion tfx/types/channel_test.py
@@ -16,6 +16,7 @@
from unittest import mock

import tensorflow as tf
+from tfx.dsl.components.base.testing import test_node
from tfx.dsl.input_resolution import resolver_op
from tfx.dsl.placeholder import placeholder
from tfx.types import artifact
@@ -90,7 +91,11 @@ def testJsonRoundTripUnknownArtifactClass(self):
self.assertTrue(rehydrated.type._AUTOGENERATED)

def testFutureProducesPlaceholder(self):
-    chnl = channel.Channel(type=_MyType)
+    chnl = channel.OutputChannel(
+        artifact_type=_MyType,
+        producer_component=test_node.TestNode('producer'),
+        output_key='foo',
+    )
future = chnl.future()
self.assertIsInstance(future, placeholder.ChannelWrappedPlaceholder)
self.assertIs(future.channel, chnl)