# Copyright 2019 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Abstract TFX driver class."""
import os
from typing import Any, Dict, List
import absl
from tfx import types
from tfx.dsl.io import fileio
from tfx.orchestration import data_types
from tfx.orchestration import metadata
from tfx.types import artifact_utils
from tfx.types import channel_utils


def _generate_output_uri(base_output_dir: str,
name: str,
execution_id: int,
is_single_artifact: bool = True,
index: int = 0) -> str:
"""Generate uri for output artifact."""
if is_single_artifact:
# TODO(b/145680633): Consider differentiating different types of uris.
return os.path.join(base_output_dir, name, str(execution_id))
return os.path.join(base_output_dir, name, str(execution_id), str(index))
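
# For illustration only (hypothetical values): with base_output_dir
# '/pipeline_root/MyComponent', name 'examples' and execution_id 5, a
# single-artifact output gets uri '/pipeline_root/MyComponent/examples/5',
# while the artifact at index 2 of a multi-artifact output gets
# '/pipeline_root/MyComponent/examples/5/2'.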


def _prepare_output_paths(artifact: types.Artifact):
"""Create output directories for output artifact."""
if fileio.exists(artifact.uri):
absl.logging.warning('Output artifact uri %s already exists', artifact.uri)
# TODO(b/158689199): We currently simply return as a short-term workaround
    # to unblock execution retries. A comprehensive solution to guarantee
# idempotent executions is needed.
return
# TODO(b/147242148): Introduce principled artifact structure (directory
# or file) definition.
if isinstance(artifact, types.ValueArtifact):
artifact_dir = os.path.dirname(artifact.uri)
else:
artifact_dir = artifact.uri
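  # Note: a ValueArtifact's uri points at a file whose content is the value
  # itself, so only its parent directory is created here; other artifact
  # types own their uri as a directory.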
# TODO(zhitaoli): Consider refactoring this out into something
# which can handle permission bits.
absl.logging.debug('Creating output artifact uri %s as directory',
artifact_dir)
fileio.makedirs(artifact_dir)


class BaseDriver:
  """BaseDriver is the base class of all custom drivers.

  This can also be used as the default driver of a component if no custom logic
  is needed.

  Attributes:
    _metadata_handle: An instance of Metadata.
  """

  def __init__(self, metadata_handle: metadata.Metadata):
self._metadata_handle = metadata_handle

  def verify_input_artifacts(
      self, artifacts_dict: Dict[str, List[types.Artifact]]) -> None:
    """Verify that all artifacts have an existing uri.

    Args:
      artifacts_dict: key -> types.Artifact for inputs.

    Raises:
      RuntimeError: if any input has an empty or non-existing uri.
    """
    artifact_utils.verify_artifacts(artifacts_dict)

  def _log_properties(self, input_dict: Dict[str, List[types.Artifact]],
output_dict: Dict[str, List[types.Artifact]],
exec_properties: Dict[str, Any]):
"""Log inputs, outputs, and executor properties in a standard format."""
absl.logging.debug('Starting %s driver.', self.__class__.__name__)
absl.logging.debug('Inputs for %s are: %s', self.__class__.__name__,
input_dict)
absl.logging.debug('Execution properties for %s are: %s',
self.__class__.__name__, exec_properties)
absl.logging.debug('Outputs for %s are: %s', self.__class__.__name__,
output_dict)

  def resolve_input_artifacts(
      self,
      input_dict: Dict[str, types.BaseChannel],
      exec_properties: Dict[str, Any],  # pylint: disable=unused-argument
      driver_args: data_types.DriverArgs,
      pipeline_info: data_types.PipelineInfo,
  ) -> Dict[str, List[types.Artifact]]:
    """Resolve input artifacts from metadata.

    Subclasses might override this function to customize artifact resolution
    logic. Note, however, that this function is expected to be called in all
    normal cases (except at the head of the pipeline), since it handles the
    artifact info passed from upstream components.

    Args:
      input_dict: key -> Channel mapping for inputs generated in the logical
        pipeline.
      exec_properties: Dict of other execution properties, e.g., configs.
      driver_args: An instance of data_types.DriverArgs with driver
        configuration properties.
      pipeline_info: An instance of data_types.PipelineInfo, holding pipeline
        related properties including pipeline_name, pipeline_root and run_id.

    Returns:
      Final artifacts that will be used in execution.

    Raises:
      ValueError: if in interactive mode, the given input channels have not
        been resolved.
    """
result = {}
for name, value in input_dict.items():
artifacts_by_id = {} # Deduplicate by ID.
# TODO(b/248145891): Stop using get_individual_channels which does not
# support all BaseChannel types. Use a common input resolution stack
# instead.
for input_channel in channel_utils.get_individual_channels(value):
if driver_args.interactive_resolution:
artifacts = list(input_channel.get())
for artifact in artifacts:
# Note: when not initialized, artifact.uri is '' and artifact.id is
# 0.
if not artifact.uri or not artifact.id:
raise ValueError(
f'Unresolved input channel {repr(artifact)} for input '
f'{repr(name)} was passed in interactive mode. When running '
'in interactive mode, upstream components must first be run '
'with `interactive_context.run(component)` before their '
'outputs can be used in downstream components.')
artifacts_by_id.update({a.id: a for a in artifacts})
else:
artifacts = self._metadata_handle.search_artifacts(
artifact_name=input_channel.output_key,
pipeline_info=pipeline_info,
producer_component_id=input_channel.producer_component_id,
)
# TODO(ccy): add this code path to interactive resolution.
for artifact in artifacts:
if isinstance(artifact, types.ValueArtifact):
# Resolve the content of file into value field for value
# artifacts.
_ = artifact.read()
artifacts_by_id.update({a.id: a for a in artifacts})
result[name] = list(artifacts_by_id.values())
return result
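
  # A minimal sketch (hypothetical subclass and policy) of overriding the
  # resolution logic above, keeping only the newest artifact per input key:
  #
  #   class LatestOnlyDriver(BaseDriver):
  #
  #     def resolve_input_artifacts(self, input_dict, exec_properties,
  #                                 driver_args, pipeline_info):
  #       resolved = super().resolve_input_artifacts(
  #           input_dict, exec_properties, driver_args, pipeline_info)
  #       return {key: [max(artifacts, key=lambda a: a.id)] if artifacts
  #               else artifacts for key, artifacts in resolved.items()}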

  def resolve_exec_properties(
      self,
      exec_properties: Dict[str, Any],
      pipeline_info: data_types.PipelineInfo,  # pylint: disable=unused-argument
      component_info: data_types.ComponentInfo,  # pylint: disable=unused-argument
  ) -> Dict[str, Any]:
    """Resolve execution properties.

    Subclasses might override this function for customized execution properties
    resolution logic.

    Args:
      exec_properties: Original execution properties passed in.
      pipeline_info: An instance of data_types.PipelineInfo, holding pipeline
        related properties including pipeline_name, pipeline_root and run_id.
      component_info: An instance of data_types.ComponentInfo, holding
        component related properties including component_type and component_id.

    Returns:
      Final execution properties that will be used in execution.
    """
    return exec_properties
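
  # A minimal sketch (hypothetical subclass) of customizing execution
  # properties, e.g. stamping the pipeline run id at runtime:
  #
  #   class RunStampingDriver(BaseDriver):
  #
  #     def resolve_exec_properties(self, exec_properties, pipeline_info,
  #                                 component_info):
  #       resolved = dict(exec_properties)
  #       resolved['resolved_run_id'] = pipeline_info.run_id
  #       return resolved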

  def _prepare_output_artifacts(
self,
input_artifacts: Dict[str, List[types.Artifact]],
output_dict: Dict[str, types.Channel],
exec_properties: Dict[str, Any],
execution_id: int,
pipeline_info: data_types.PipelineInfo,
component_info: data_types.ComponentInfo,
) -> Dict[str, List[types.Artifact]]:
"""Prepare output artifacts by assigning uris to each artifact."""
del exec_properties
base_output_dir = os.path.join(pipeline_info.pipeline_root,
component_info.component_id)
result = {}
for name, channel in output_dict.items():
if channel.matching_channel_name:
# Decides the artifact count for output Channel at runtime based on the
# artifact count in specified input Channel.
count = len(input_artifacts[channel.matching_channel_name])
output_list = [channel.type() for _ in range(count)]
else:
output_list = [channel.type()]
is_single_artifact = len(output_list) == 1
for i, artifact in enumerate(output_list):
artifact.name = f'{name}:{pipeline_info.run_id}'
artifact.producer_component = component_info.component_id
artifact.uri = _generate_output_uri(base_output_dir, name, execution_id,
is_single_artifact, i)
# TODO(b/147242148): Introduce principled artifact structure (directory
# or file) definition.
if isinstance(artifact, types.ValueArtifact):
artifact.uri = os.path.join(artifact.uri, 'value')
_prepare_output_paths(artifact)
result[name] = output_list
return result
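
  # For illustration only (hypothetical counts): if an output channel declares
  # matching_channel_name='examples' and the input 'examples' resolved to
  # three artifacts, three output artifacts are prepared here, with uris
  # suffixed /0, /1 and /2 by _generate_output_uri.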

  def pre_execution(
      self,
      input_dict: Dict[str, types.BaseChannel],
      output_dict: Dict[str, types.Channel],
      exec_properties: Dict[str, Any],
      driver_args: data_types.DriverArgs,
      pipeline_info: data_types.PipelineInfo,
      component_info: data_types.ComponentInfo,
  ) -> data_types.ExecutionDecision:
    """Handle pre-execution logic.

    There are four steps:
      1. Fetches input artifacts from metadata and checks whether each uri
         exists.
      2. Registers the execution.
      3. Decides whether a new execution is needed.
      4a. If a new execution is needed, prepares output artifacts.
      4b. Otherwise, fetches the cached output artifacts.

    Args:
      input_dict: key -> Channel for inputs.
      output_dict: key -> Channel for outputs. Uris of the outputs are not
        assigned.
      exec_properties: Dict of other execution properties.
      driver_args: An instance of the data_types.DriverArgs class.
      pipeline_info: An instance of data_types.PipelineInfo, holding pipeline
        related properties including pipeline_name, pipeline_root and run_id.
      component_info: An instance of data_types.ComponentInfo, holding
        component related properties including component_type and component_id.

    Returns:
      A data_types.ExecutionDecision object.

    Raises:
      RuntimeError: if any input has an empty uri.
    """
# Step 1. Fetch inputs from metadata.
exec_properties = self.resolve_exec_properties(exec_properties,
pipeline_info,
component_info)
input_artifacts = self.resolve_input_artifacts(input_dict, exec_properties,
driver_args, pipeline_info)
self.verify_input_artifacts(artifacts_dict=input_artifacts)
absl.logging.debug('Resolved input artifacts are: %s', input_artifacts)
# Step 2. Register execution in metadata.
contexts = self._metadata_handle.register_pipeline_contexts_if_not_exists(
pipeline_info
)
execution = self._metadata_handle.register_execution(
input_artifacts=input_artifacts,
exec_properties=exec_properties,
pipeline_info=pipeline_info,
component_info=component_info,
contexts=contexts,
)
use_cached_results = False
output_artifacts = None
if driver_args.enable_cache:
# Step 3. Decide whether a new execution is needed.
output_artifacts = self._metadata_handle.get_cached_outputs(
input_artifacts=input_artifacts,
exec_properties=exec_properties,
pipeline_info=pipeline_info,
component_info=component_info,
)
      # Check that cached output artifacts will actually be considered a cache
      # hit by downstream components.
if output_artifacts is not None:
try:
artifact_utils.verify_artifacts(output_artifacts)
use_cached_results = True
except RuntimeError:
absl.logging.debug(
'Cached results found but could not be verified to still exist')
if use_cached_results:
      # If the cache should be used, update the execution to reflect that.
      # Note that with this update, the publisher will be skipped.
self._metadata_handle.update_execution(
execution=execution,
component_info=component_info,
output_artifacts=output_artifacts,
execution_state=metadata.EXECUTION_STATE_CACHED,
contexts=contexts,
)
else:
absl.logging.debug(
          'Cached results not available, moving on to a new execution')
# Step 4a. New execution is needed. Prepare output artifacts.
output_artifacts = self._prepare_output_artifacts(
input_artifacts=input_artifacts,
output_dict=output_dict,
exec_properties=exec_properties,
execution_id=execution.id,
pipeline_info=pipeline_info,
component_info=component_info)
absl.logging.debug(
          'Output artifact skeletons for the upcoming execution are: %s',
output_artifacts)
# Updates the execution to reflect refreshed output artifacts and
# execution properties.
self._metadata_handle.update_execution(
execution=execution,
component_info=component_info,
output_artifacts=output_artifacts,
exec_properties=exec_properties,
contexts=contexts,
)
absl.logging.debug(
'Execution properties for the upcoming execution are: %s',
exec_properties)
# For interactive execution, update the output channel contents.
# TODO(b/161490287): figure out the long-term behavior of Channel artifacts
# with respect to interactive and non-interactive execution.
if driver_args.interactive_resolution:
for key, artifact_list in output_artifacts.items():
channel = output_dict[key]
channel._artifacts = artifact_list # pylint: disable=protected-access
return data_types.ExecutionDecision(input_artifacts, output_artifacts,
exec_properties, execution.id,
use_cached_results)
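

# A minimal sketch (hypothetical setup and names) of how an orchestrator might
# drive a component with this class:
#
#   with metadata.Metadata(connection_config) as metadata_handle:
#     driver = BaseDriver(metadata_handle)
#     decision = driver.pre_execution(input_dict, output_dict, exec_properties,
#                                     driver_args, pipeline_info,
#                                     component_info)
#     if not decision.use_cached_results:
#       # Hand decision.input_dict / decision.output_dict /
#       # decision.exec_properties to the executor, then publish the results.
#       ...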