Commit
Create the component_generated_alert proto, add `ComponentGeneratedAlert` event to `event_observer`, and add functionality to notify the `ComponentGeneratedAlert` event in `post_execution_utils`.

PiperOrigin-RevId: 574574472
tfx-copybara committed Oct 18, 2023
1 parent ceea6bd commit 13c3e1b
Showing 9 changed files with 166 additions and 23 deletions.
28 changes: 28 additions & 0 deletions tfx/orchestration/experimental/core/component_generated_alert.proto
@@ -0,0 +1,28 @@
// Copyright 2023 Google LLC. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Messages for configuring component generated alerts.

syntax = "proto2";

package tfx.orchestration.experimental.core;

message ComponentGeneratedAlertInfo {
  optional string alert_name = 1;
  optional string alert_body = 2;
}

message ComponentGeneratedAlertList {
  repeated ComponentGeneratedAlertInfo component_generated_alert_list = 1;
}
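
For context, a minimal sketch of how a component executor could emit these alerts: build a ComponentGeneratedAlertList and pack it into ExecutorOutput.execution_properties under the `__component_generated_alerts__` key that `post_execution_utils` reads below. The alert name and body are placeholders; the packing itself mirrors this commit's unit test.

from tfx.orchestration.experimental.core import component_generated_alert_pb2
from tfx.proto.orchestration import execution_result_pb2

# Same key as _COMPONENT_GENERATED_ALERTS_KEY in post_execution_utils.py.
_ALERTS_KEY = '__component_generated_alerts__'

executor_output = execution_result_pb2.ExecutorOutput()
alerts = component_generated_alert_pb2.ComponentGeneratedAlertList()
alerts.component_generated_alert_list.append(
    component_generated_alert_pb2.ComponentGeneratedAlertInfo(
        alert_name='accuracy_below_threshold',  # placeholder
        alert_body='Eval accuracy dropped below the configured floor.',
    )
)
# Value.proto_value is a google.protobuf.Any, so the whole list is packed.
executor_output.execution_properties[_ALERTS_KEY].proto_value.Pack(alerts)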
12 changes: 11 additions & 1 deletion tfx/orchestration/experimental/core/event_observer.py
@@ -84,8 +84,18 @@ class NodeStateChange:
  new_state: Any


@dataclasses.dataclass(frozen=True)
class ComponentGeneratedAlert:
  """ComponentGeneratedAlert event."""
  execution: metadata_store_pb2.Execution
  pipeline_uid: task_lib.PipelineUid
  node_id: str
  alert_name: str
  alert_body: str


Event = Union[PipelineStarted, PipelineFinished, NodeStateChange,
              ExecutionStateChange]
              ExecutionStateChange, ComponentGeneratedAlert]

ObserverFn = Callable[[Event], None]
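
A minimal sketch of a subscriber for the new event, assuming the module's existing `register_observer` hook for `ObserverFn` callbacks; the handler body is illustrative only.

from tfx.orchestration.experimental.core import event_observer


def _log_alerts(event: event_observer.Event) -> None:
  # Event is a Union; dispatch on the concrete dataclass.
  if isinstance(event, event_observer.ComponentGeneratedAlert):
    print(f'[{event.pipeline_uid.pipeline_id}/{event.node_id}] '
          f'{event.alert_name}: {event.alert_body}')


event_observer.register_observer(_log_alerts)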

27 changes: 25 additions & 2 deletions tfx/orchestration/experimental/core/post_execution_utils.py
@@ -20,6 +20,7 @@
from tfx.dsl.io import fileio
from tfx.orchestration import data_types_utils
from tfx.orchestration import metadata
from tfx.orchestration.experimental.core import component_generated_alert_pb2
from tfx.orchestration.experimental.core import constants
from tfx.orchestration.experimental.core import event_observer
from tfx.orchestration.experimental.core import garbage_collection
@@ -36,6 +37,9 @@
from ml_metadata import proto


_COMPONENT_GENERATED_ALERTS_KEY = '__component_generated_alerts__'


def publish_execution_results_for_task(mlmd_handle: metadata.Metadata,
                                       task: task_lib.ExecNodeTask,
                                       result: ts.TaskSchedulerResult) -> None:
@@ -87,7 +91,7 @@ def _update_state(
    # TODO(b/262040844): Instead of directly using the context manager here, we
    # should consider creating and using wrapper functions.
    with mlmd_state.evict_from_cache(task.execution_id):
      execution_publish_utils.publish_succeeded_execution(
      _, execution = execution_publish_utils.publish_succeeded_execution(
          mlmd_handle,
          execution_id=task.execution_id,
          contexts=task.contexts,
@@ -96,6 +100,24 @@
    garbage_collection.run_garbage_collection_for_node(mlmd_handle,
                                                       task.node_uid,
                                                       task.get_node())
    if _COMPONENT_GENERATED_ALERTS_KEY in execution.custom_properties:
      alerts_proto = component_generated_alert_pb2.ComponentGeneratedAlertList()
      execution.custom_properties[
          _COMPONENT_GENERATED_ALERTS_KEY
      ].proto_value.Unpack(alerts_proto)

      for alert in alerts_proto.component_generated_alert_list:
        alert_event = event_observer.ComponentGeneratedAlert(
            execution=execution,
            pipeline_uid=task_lib.PipelineUid(
                pipeline_id=task.pipeline.pipeline_info.id
            ),
            node_id=task.node_uid.node_id,
            alert_body=alert.alert_body,
            alert_name=alert.alert_name,
        )
        event_observer.notify(alert_event)

  elif isinstance(result.output, ts.ImporterNodeOutput):
    output_artifacts = result.output.output_artifacts
    _remove_temporary_task_dirs(
@@ -156,12 +178,13 @@ def publish_execution_results(
  # TODO(b/262040844): Instead of directly using the context manager here, we
  # should consider creating and using wrapper functions.
  with mlmd_state.evict_from_cache(execution_info.execution_id):
    return execution_publish_utils.publish_succeeded_execution(
    output_dict, _ = execution_publish_utils.publish_succeeded_execution(
        mlmd_handle,
        execution_id=execution_info.execution_id,
        contexts=contexts,
        output_artifacts=execution_info.output_dict,
        executor_output=executor_output)
    return output_dict


def _update_execution_state_in_mlmd(
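The branch above reads the packed list back off the published MLMD execution and notifies one event per alert. As a standalone illustration (not part of this change), the unpack step could be guarded with `Any.Is()` before calling `Unpack`; `extract_alerts` below is a hypothetical helper.

from ml_metadata.proto import metadata_store_pb2
from tfx.orchestration.experimental.core import component_generated_alert_pb2

_ALERTS_KEY = '__component_generated_alerts__'


def extract_alerts(
    execution: metadata_store_pb2.Execution,
) -> component_generated_alert_pb2.ComponentGeneratedAlertList:
  """Returns alerts packed on a published execution (empty if none)."""
  alerts = component_generated_alert_pb2.ComponentGeneratedAlertList()
  if _ALERTS_KEY in execution.custom_properties:
    any_proto = execution.custom_properties[_ALERTS_KEY].proto_value
    # Is() guards against a mismatched payload; the production code above
    # unpacks directly because it controls both ends of the round trip.
    if any_proto.Is(alerts.DESCRIPTOR):
      any_proto.Unpack(alerts)
  return alerts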
65 changes: 65 additions & 0 deletions tfx/orchestration/experimental/core/post_execution_utils_test.py
@@ -20,11 +20,17 @@
from tfx.dsl.io import fileio
from tfx.orchestration import data_types_utils
from tfx.orchestration import metadata
from tfx.orchestration.experimental.core import component_generated_alert_pb2
from tfx.orchestration.experimental.core import event_observer
from tfx.orchestration.experimental.core import post_execution_utils
from tfx.orchestration.experimental.core import task as task_lib
from tfx.orchestration.experimental.core import task_scheduler as ts
from tfx.orchestration.experimental.core import test_utils
from tfx.orchestration.portable import data_types
from tfx.orchestration.portable import execution_publish_utils
from tfx.proto.orchestration import execution_invocation_pb2
from tfx.proto.orchestration import execution_result_pb2
from tfx.proto.orchestration import pipeline_pb2
from tfx.types import standard_artifacts
from tfx.utils import status as status_lib
from tfx.utils import test_case_utils as tu
@@ -102,6 +108,8 @@ def test_publish_execution_results_succeeded_execution(self, mock_publish):
    executor_output = execution_result_pb2.ExecutorOutput()
    executor_output.execution_result.code = 0

    mock_publish.return_value = [None, None]

    post_execution_utils.publish_execution_results(
        self.mlmd_handle, executor_output, execution_info, contexts=[])

@@ -113,6 +121,63 @@ def test_publish_execution_results_succeeded_execution(self, mock_publish):
        output_artifacts=execution_info.output_dict,
        executor_output=executor_output)

  @mock.patch.object(event_observer, 'notify')
  def test_publish_execution_results_for_task_with_alerts(self, mock_notify):
    _ = self._prepare_execution_info()

    executor_output = execution_result_pb2.ExecutorOutput()
    executor_output.execution_result.code = 0

    component_generated_alerts = (
        component_generated_alert_pb2.ComponentGeneratedAlertList()
    )
    component_generated_alerts.component_generated_alert_list.append(
        component_generated_alert_pb2.ComponentGeneratedAlertInfo(
            alert_name='test_alert',
            alert_body='test_alert_body',
        )
    )
    executor_output.execution_properties[
        post_execution_utils._COMPONENT_GENERATED_ALERTS_KEY
    ].proto_value.Pack(component_generated_alerts)

    [execution] = self.mlmd_handle.store.get_executions()

    # Create test pipeline.
    deployment_config = pipeline_pb2.IntermediateDeploymentConfig()
    executor_spec = pipeline_pb2.ExecutorSpec.PythonClassExecutorSpec(
        class_path='trainer.TrainerExecutor')
    deployment_config.executor_specs['AlertGenerator'].Pack(
        executor_spec
    )
    pipeline = pipeline_pb2.Pipeline()
    pipeline.nodes.add().pipeline_node.node_info.id = 'AlertGenerator'
    pipeline.pipeline_info.id = 'test-pipeline'
    pipeline.deployment_config.Pack(deployment_config)

    node_uid = task_lib.NodeUid(
        pipeline_uid=task_lib.PipelineUid(
            pipeline_id=pipeline.pipeline_info.id
        ),
        node_id='AlertGenerator',
    )
    task = test_utils.create_exec_node_task(
        node_uid=node_uid,
        execution=execution,
        pipeline=pipeline,
    )
    result = ts.TaskSchedulerResult(
        status=status_lib.Status(
            code=status_lib.Code.OK,
            message='test TaskScheduler result'
        ),
        output=ts.ExecutorNodeOutput(executor_output=executor_output)
    )
    post_execution_utils.publish_execution_results_for_task(
        self.mlmd_handle, task, result
    )
    mock_notify.assert_called_once()


if __name__ == '__main__':
  tf.test.main()
6 changes: 4 additions & 2 deletions tfx/orchestration/experimental/core/test_utils.py
@@ -266,8 +266,10 @@ def fake_finish_node_with_handle(
  else:
    output_artifacts = None
  contexts = context_lib.prepare_contexts(mlmd_handle, node.contexts)
  return execution_publish_utils.publish_succeeded_execution(
      mlmd_handle, execution_id, contexts, output_artifacts)
  output_dict, _ = execution_publish_utils.publish_succeeded_execution(
      mlmd_handle, execution_id, contexts, output_artifacts
  )
  return output_dict


def create_exec_node_task(
2 changes: 1 addition & 1 deletion tfx/orchestration/portable/cache_utils_test.py
@@ -215,7 +215,7 @@ def testGetCachedOutputArtifacts(self, mock_verify_artifacts):
        })
    execution_two = execution_publish_utils.register_execution(
        m, metadata_store_pb2.ExecutionType(name='my_type'), [cache_context])
    output_artifacts = execution_publish_utils.publish_succeeded_execution(
    output_artifacts, _ = execution_publish_utils.publish_succeeded_execution(
        m,
        execution_two.id, [cache_context],
        output_artifacts={
16 changes: 10 additions & 6 deletions tfx/orchestration/portable/execution_publish_utils.py
@@ -78,7 +78,10 @@ def publish_succeeded_execution(
    contexts: Sequence[metadata_store_pb2.Context],
    output_artifacts: Optional[typing_utils.ArtifactMultiMap] = None,
    executor_output: Optional[execution_result_pb2.ExecutorOutput] = None,
) -> Optional[typing_utils.ArtifactMultiMap]:
) -> tuple[
    Optional[typing_utils.ArtifactMultiMap],
    metadata_store_pb2.Execution,
]:
"""Marks an existing execution as success.
Also publishes the output artifacts produced by the execution. This method
@@ -102,9 +105,10 @@ def publish_succeeded_execution(
      artifact should not change the type of the artifact.

  Returns:
    The maybe updated output_artifacts, note that only outputs whose key are in
    executor_output will be updated and others will be untouched. That said,
    it can be partially updated.
    The tuple containing the maybe updated output_artifacts (note that only
    outputs whose key are in executor_output will be updated and others will be
    untouched, that said, it can be partially updated) and the written
    execution.

  Raises:
    RuntimeError: if the executor output to an output channel is partial.
  """
@@ -147,14 +151,14 @@
      execution.custom_properties[key].CopyFrom(value)
  set_execution_result_if_not_empty(executor_output, execution)

  execution_lib.put_execution(
  execution = execution_lib.put_execution(
      metadata_handle,
      execution,
      contexts,
      output_artifacts=output_artifacts_to_publish,
  )

  return output_artifacts_to_publish
  return output_artifacts_to_publish, execution


def publish_failed_execution(
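Since `publish_succeeded_execution` now returns a `(output_artifacts, execution)` tuple instead of the artifact map alone, call sites migrate as the updated callers throughout this commit do. A before/after sketch (the argument variables are assumed to be in scope):

# Before: the call returned only the (maybe updated) output artifact map.
# output_dict = execution_publish_utils.publish_succeeded_execution(...)

# After: it also returns the written Execution, which lets callers such as
# post_execution_utils read packed alerts off execution.custom_properties.
output_dict, execution = execution_publish_utils.publish_succeeded_execution(
    mlmd_handle,
    execution_id=execution_id,
    contexts=contexts,
    output_artifacts=output_artifacts,
    executor_output=executor_output,
)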
30 changes: 20 additions & 10 deletions tfx/orchestration/portable/execution_publish_utils_test.py
@@ -191,10 +191,15 @@ def testPublishSuccessfulExecution(self):
          value {int_value: 1}
        }
        """, executor_output.output_artifacts[output_key].artifacts.add())
    output_dict = execution_publish_utils.publish_succeeded_execution(
        m, execution_id, contexts, {output_key: [output_example]},
        executor_output)
    [execution] = m.store.get_executions()
    output_dict, execution = (
        execution_publish_utils.publish_succeeded_execution(
            m,
            execution_id,
            contexts,
            {output_key: [output_example]},
            executor_output,
        )
    )
    self.assertProtoPartiallyEquals(
        """
        id: 1
@@ -303,7 +308,7 @@ def testPublishSuccessfulExecutionWithRuntimeResolvedUri(self):
        }}
        """, executor_output.output_artifacts[output_key].artifacts.add())

    output_dict = execution_publish_utils.publish_succeeded_execution(
    output_dict, _ = execution_publish_utils.publish_succeeded_execution(
        m, execution_id, contexts, {output_key: [output_example]},
        executor_output)
    self.assertLen(output_dict[output_key], 2)
@@ -361,7 +366,7 @@ def testPublishSuccessfulExecutionOmitsArtifactIfNotResolvedDuringRuntime(
          value {{int_value: 1}}
        }}
        """, executor_output.output_artifacts['key1'].artifacts.add())
    output_dict = execution_publish_utils.publish_succeeded_execution(
    output_dict, _ = execution_publish_utils.publish_succeeded_execution(
        m, execution_id, contexts, original_artifacts, executor_output)
    self.assertEmpty(output_dict['key1'])
    self.assertNotEmpty(output_dict['key2'])
@@ -414,10 +419,15 @@ def testPublishSuccessExecutionExecutorEditedOutputDict(self):
        }
        """, executor_output.output_artifacts[output_key].artifacts.add())

    output_dict = execution_publish_utils.publish_succeeded_execution(
        m, execution_id, contexts, {output_key: [output_example]},
        executor_output)
    [execution] = m.store.get_executions()
    output_dict, execution = (
        execution_publish_utils.publish_succeeded_execution(
            m,
            execution_id,
            contexts,
            {output_key: [output_example]},
            executor_output,
        )
    )
    self.assertProtoPartiallyEquals(
        """
        id: 1
3 changes: 2 additions & 1 deletion tfx/orchestration/portable/inputs_utils_test.py
@@ -72,9 +72,10 @@ def fake_execute(self, metadata_handle, pipeline_node, input_map, output_map):
    execution = execution_publish_utils.register_execution(
        metadata_handle, pipeline_node.node_info.type, contexts, input_map
    )
    return execution_publish_utils.publish_succeeded_execution(
    output_dict, _ = execution_publish_utils.publish_succeeded_execution(
        metadata_handle, execution.id, contexts, output_map
    )
    return output_dict

  def assertArtifactEqual(self, expected, actual):
    self.assertProtoPartiallyEquals(
