-
Notifications
You must be signed in to change notification settings - Fork 706
/
container_component.py
120 lines (104 loc) · 4.1 KB
/
container_component.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions for creating container components."""
from typing import Any, Callable, Dict, List, Optional
from tfx.dsl.component.experimental import component_utils
from tfx.dsl.component.experimental import executor_specs
from tfx.dsl.component.experimental import placeholders
from tfx.dsl.components.base import base_component
from tfx.types import channel_utils
from tfx.types import component_spec
def create_container_component(
    name: str,
    image: str,
    command: List[placeholders.CommandlineArgumentType],
    inputs: Optional[Dict[str, Any]] = None,
    outputs: Optional[Dict[str, Any]] = None,
    parameters: Optional[Dict[str, Any]] = None,
) -> Callable[..., base_component.BaseComponent]:
  """Creates a container-based component.

  Args:
    name: The name of the component. Must be non-empty.
    image: Container image name.
    command: Container entrypoint command-line. Not executed within a shell. The
      command-line can use placeholder objects that will be replaced at the
      compilation time. The placeholder objects can be imported from
      tfx.dsl.component.experimental.placeholders. Note that Jinja templates are
      not supported.
    inputs: Optional mapping from input name to the artifact type of that
      input. `None` means the component has no inputs.
    outputs: Optional mapping from output name to the artifact type of that
      output. `None` means the component has no outputs.
    parameters: Optional mapping from parameter name to the type of that
      execution parameter. `None` means the component has no parameters.

  Returns:
    Component that can be instantiated and used inside a pipeline.

  Raises:
    ValueError: If `name` is empty.

  Example:
    ```
    component = create_container_component(
        name='TrainModel',
        inputs={
            'training_data': Dataset,
        },
        outputs={
            'model': Model,
        },
        parameters={
            'num_training_steps': int,
        },
        image='gcr.io/my-project/my-trainer',
        command=[
            'python3', 'my_trainer',
            '--training_data_uri', InputUriPlaceholder('training_data'),
            '--model_uri', OutputUriPlaceholder('model'),
            '--num_training_steps', InputValuePlaceholder('num_training_steps'),
        ]
    )
    ```
  """
  if not name:
    raise ValueError('Component name cannot be empty.')

  input_specs = {} if inputs is None else inputs
  output_specs = {} if outputs is None else outputs
  parameter_specs = {} if parameters is None else parameters

  # TODO(b/155804245) Sanitize the names so that they're valid python names
  input_channel_parameters = {
      key: component_spec.ChannelParameter(type=artifact_type)
      for key, artifact_type in input_specs.items()
  }

  # Outputs need two things per name: a spec entry and a default channel that
  # wraps one freshly constructed artifact of the declared type. Both are
  # built in the same iteration so constructor calls happen in a fixed order.
  # NOTE(review): these channels are created once per call of this function
  # and handed over as default_init_args; if create_tfx_component_class reuses
  # them for every instance of the returned class, instances would share
  # output channel objects — confirm against component_utils.
  output_channel_parameters = {}
  default_init_args = {}
  for key, artifact_type in output_specs.items():
    # TODO(b/155804245) Sanitize the names so that they're valid python names
    output_channel_parameters[key] = component_spec.ChannelParameter(
        type=artifact_type)
    default_init_args[key] = channel_utils.as_channel([artifact_type()])

  # TODO(b/155804245) Sanitize the names so that they're valid python names
  execution_parameters = {
      key: component_spec.ExecutionParameter(type=param_type)
      for key, param_type in parameter_specs.items()
  }

  executor_spec = executor_specs.TemplatedExecutorContainerSpec(
      image=image,
      command=command,
  )
  return component_utils.create_tfx_component_class(
      name=name,
      tfx_executor_spec=executor_spec,
      input_channel_parameters=input_channel_parameters,
      output_channel_parameters=output_channel_parameters,
      execution_parameters=execution_parameters,
      default_init_args=default_init_args,
  )