Skip any non-live artifacts and any non-final executions in the base driver's input resolution. #4093

Closed
wants to merge 1 commit
1 change: 0 additions & 1 deletion RELEASE.md
@@ -1,7 +1,6 @@
 # Current Version (Still in Development)
 
 ## Major Features and Improvements
-* Added RuntimeParam support for Trainer's custom_config.
 
 ## Breaking Changes
 
7 changes: 2 additions & 5 deletions tfx/components/trainer/component.py
@@ -87,8 +87,7 @@ def __init__(
                                  data_types.RuntimeParameter]] = None,
       eval_args: Optional[Union[trainer_pb2.EvalArgs,
                                 data_types.RuntimeParameter]] = None,
-      custom_config: Optional[Union[Dict[Text, Any],
-                                    data_types.RuntimeParameter]] = None,
+      custom_config: Optional[Dict[Text, Any]] = None,
       custom_executor_spec: Optional[executor_spec.ExecutorSpec] = None):
     """Construct a Trainer component.

@@ -196,9 +195,7 @@ def trainer_fn(trainer.fn_args_utils.FnArgs,
         module_file=module_file,
         run_fn=run_fn,
         trainer_fn=trainer_fn,
-        custom_config=(custom_config
-                       if isinstance(custom_config, data_types.RuntimeParameter)
-                       else json_utils.dumps(custom_config)),
+        custom_config=json_utils.dumps(custom_config),
         model=model,
         model_run=model_run)
     super(Trainer, self).__init__(
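
Editor's note: with this revert, `custom_config` is a plain dict again and is always JSON-serialized via `json_utils.dumps`; passing a `data_types.RuntimeParameter` is no longer supported. A minimal sketch of the resulting call site (the file paths and the upstream `example_gen` node are hypothetical, not taken from this PR):

```python
from tfx.components import CsvExampleGen, Trainer
from tfx.proto import trainer_pb2

# Hypothetical pipeline snippet illustrating the reverted API surface.
example_gen = CsvExampleGen(input_base='/path/to/data')    # hypothetical path
trainer = Trainer(
    module_file='/path/to/trainer_module.py',              # hypothetical path
    examples=example_gen.outputs['examples'],
    train_args=trainer_pb2.TrainArgs(num_steps=100),
    eval_args=trainer_pb2.EvalArgs(num_steps=50),
    custom_config={'test': 10})  # plain dict; serialized via json_utils.dumps
```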
32 changes: 1 addition & 31 deletions tfx/components/trainer/component_test.py
@@ -59,15 +59,11 @@ def testConstructFromModuleFile(self):
         module_file=module_file,
         examples=self.examples,
         transform_graph=self.transform_graph,
-        schema=self.schema,
-        custom_config={'test': 10})
+        schema=self.schema)
     self._verify_outputs(trainer)
     self.assertEqual(
         module_file,
         trainer.spec.exec_properties[standard_component_specs.MODULE_FILE_KEY])
-    self.assertEqual(
-        '{"test": 10}', trainer.spec.exec_properties[
-            standard_component_specs.CUSTOM_CONFIG_KEY])
 
   def testConstructWithParameter(self):
     module_file = data_types.RuntimeParameter(name='module-file', ptype=Text)
@@ -188,32 +184,6 @@ def testConstructWithHParams(self):
         standard_artifacts.HyperParameters.TYPE_NAME,
         trainer.inputs[standard_component_specs.HYPERPARAMETERS_KEY].type_name)
 
-  def testConstructWithRuntimeParam(self):
-    eval_args = data_types.RuntimeParameter(
-        name='eval-args',
-        default='{"num_steps": 50}',
-        ptype=Text,
-    )
-    custom_config = data_types.RuntimeParameter(
-        name='custom-config',
-        default='{"test": 10}',
-        ptype=Text,
-    )
-    trainer = component.Trainer(
-        trainer_fn='path.to.my_trainer_fn',
-        examples=self.examples,
-        train_args=self.train_args,
-        eval_args=eval_args,
-        custom_config=custom_config)
-    self._verify_outputs(trainer)
-    self.assertIsInstance(
-        trainer.spec.exec_properties[standard_component_specs.EVAL_ARGS_KEY],
-        data_types.RuntimeParameter)
-    self.assertIsInstance(
-        trainer.spec.exec_properties[
-            standard_component_specs.CUSTOM_CONFIG_KEY],
-        data_types.RuntimeParameter)
 
 
 if __name__ == '__main__':
   tf.test.main()
26 changes: 22 additions & 4 deletions tfx/orchestration/metadata.py
@@ -51,6 +51,17 @@
 EXECUTION_STATE_NEW = 'new'
 FINAL_EXECUTION_STATES = frozenset(
     (EXECUTION_STATE_CACHED, EXECUTION_STATE_COMPLETE))
+
+
+def _is_execution_final(execution: metadata_store_pb2.Execution) -> bool:
+  return (execution.properties[_EXECUTION_TYPE_KEY_STATE].string_value
+          in FINAL_EXECUTION_STATES)
+
+
+# Note: this helper receives the deserialized TFX artifact (which wraps the
+# MLMD proto as `mlmd_artifact`), so it is annotated as types.Artifact rather
+# than metadata_store_pb2.Artifact.
+def _is_artifact_live(artifact: types.Artifact) -> bool:
+  return artifact.mlmd_artifact.state == metadata_store_pb2.Artifact.LIVE
+
+
 # Context type, the following three types of contexts are supported:
 # - pipeline level context is shared within one pipeline, across multiple
 #   pipeline runs.
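
Editor's note: for orientation, a standalone sketch of what the two new predicates check, assuming `_EXECUTION_TYPE_KEY_STATE == 'state'` (its value elsewhere in this module) and operating on the raw MLMD protos directly:

```python
from ml_metadata.proto import metadata_store_pb2

FINAL_EXECUTION_STATES = frozenset(('cached', 'complete'))

def is_execution_final(execution: metadata_store_pb2.Execution) -> bool:
  # Final executions ('cached' or 'complete') have published their outputs;
  # states such as 'new' are still in flight.
  return execution.properties['state'].string_value in FINAL_EXECUTION_STATES

def is_artifact_live(mlmd_artifact: metadata_store_pb2.Artifact) -> bool:
  # LIVE is MLMD's state for a usable, fully published artifact.
  return mlmd_artifact.state == metadata_store_pb2.Artifact.LIVE

# Quick check with an in-memory proto:
execution = metadata_store_pb2.Execution()
execution.properties['state'].string_value = 'complete'
assert is_execution_final(execution)
```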
@@ -98,8 +109,7 @@ def sqlite_metadata_connection_config(
   fileio.makedirs(os.path.dirname(metadata_db_uri))
   connection_config = metadata_store_pb2.ConnectionConfig()
   connection_config.sqlite.filename_uri = metadata_db_uri
-  connection_config.sqlite.connection_mode = \
-      metadata_store_pb2.SqliteMetadataSourceConfig.READWRITE_OPENCREATE
+  connection_config.sqlite.connection_mode = metadata_store_pb2.SqliteMetadataSourceConfig.READWRITE_OPENCREATE
   return connection_config


@@ -739,8 +749,7 @@ def publish_execution(
     ]
     contexts = [ctx for ctx in contexts if ctx is not None]
     # If execution state is already in final state, skips publishing.
-    if execution.properties[
-        _EXECUTION_TYPE_KEY_STATE].string_value in FINAL_EXECUTION_STATES:
+    if _is_execution_final(execution):
       return
     self.update_execution(
         execution=execution,
@@ -945,6 +954,12 @@ def search_artifacts(self, artifact_name: Text,
     if context is None:
       raise RuntimeError('Pipeline run context for %s does not exist' %
                          pipeline_info)
-    for execution in self.store.get_executions_by_context(context.id):
+    exec_by_context = self.store.get_executions_by_context(context.id)
+    # filter() needs the iterable as its second argument, and the comparison
+    # must be between two lengths, not a length and a list.
+    exec_in_final_state = list(filter(_is_execution_final, exec_by_context))
+    if len(exec_in_final_state) < len(exec_by_context):
+      absl.logging.info(
+          'Skipped some non-final executions in given context %d: %s',
+          context.id, exec_by_context)
+    # Only executions in a final state participate in input resolution.
+    for execution in exec_in_final_state:
       if execution.properties[
           'component_id'].string_value == producer_component_id:
@@ -974,6 +989,9 @@ def search_artifacts(self, artifact_name: Text,
       for a in artifacts_by_id:
         tfx_artifact = artifact_utils.deserialize_artifact(
             artifact_types[a.type_id], a)
+        if not _is_artifact_live(tfx_artifact):
+          absl.logging.warning(
+              'Skipping non-live artifact %s in context (id=%d)',
+              tfx_artifact, context.id)
+          continue  # Without this, the artifact would be appended anyway.
         result_artifacts.append(tfx_artifact)
       return result_artifacts
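
Editor's note: taken together, `search_artifacts` now walks only executions in a final state and returns only LIVE artifacts, which is what the base driver's input resolution consumes. A hedged usage sketch (the database path, `pipeline_info` object, and component id are placeholders, not values from this PR):

```python
from tfx.orchestration import metadata

# Hypothetical driver-side lookup.
connection_config = metadata.sqlite_metadata_connection_config(
    '/tmp/tfx/metadata.db')  # hypothetical path
with metadata.Metadata(connection_config) as m:
  examples = m.search_artifacts(
      artifact_name='examples',
      pipeline_info=pipeline_info,            # hypothetical PipelineInfo
      producer_component_id='CsvExampleGen')  # hypothetical component id
  # Every artifact returned is LIVE and was produced by a final execution.
```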
