From 5b93b5d9ae921daf0194d4a9f514f358dc13f5a5 Mon Sep 17 00:00:00 2001 From: tfx-team Date: Fri, 12 Jul 2024 09:21:32 -0700 Subject: [PATCH] Orchestrator shouldn't crash when MLMD call fails PiperOrigin-RevId: 651796197 --- tfx/orchestration/experimental/core/env.py | 19 +++++ .../experimental/core/env_test.py | 5 ++ .../experimental/core/pipeline_ops.py | 19 ++++- .../experimental/core/pipeline_ops_test.py | 78 +++++++++++++++++++ .../experimental/core/test_utils.py | 12 +++ tfx/utils/status.py | 20 +++++ 6 files changed, 150 insertions(+), 3 deletions(-) diff --git a/tfx/orchestration/experimental/core/env.py b/tfx/orchestration/experimental/core/env.py index 37ba79a889..5ec0496cd6 100644 --- a/tfx/orchestration/experimental/core/env.py +++ b/tfx/orchestration/experimental/core/env.py @@ -137,6 +137,20 @@ def should_orchestrate(self, pipeline: pipeline_pb2.Pipeline) -> bool: Whether the env should orchestrate the pipeline. """ + @abc.abstractmethod + def get_status_code_from_exception( + self, exception: Optional[BaseException] + ) -> Optional[int]: + """Returns the status code from the given exception. + + Args: + exception: An exception. + + Returns: + Code of the exception. + Returns None if the exception is not a known type. + """ + class _DefaultEnv(Env): """Default environment.""" @@ -211,6 +225,11 @@ def should_orchestrate(self, pipeline: pipeline_pb2.Pipeline) -> bool: # By default, all pipeline runs should be orchestrated. return True + def get_status_code_from_exception( + self, exception: Optional[BaseException] + ) -> Optional[int]: + return None + _ENV = _DefaultEnv() diff --git a/tfx/orchestration/experimental/core/env_test.py b/tfx/orchestration/experimental/core/env_test.py index a5f5e3e605..4cd0b721c8 100644 --- a/tfx/orchestration/experimental/core/env_test.py +++ b/tfx/orchestration/experimental/core/env_test.py @@ -60,6 +60,11 @@ def prepare_orchestrator_for_pipeline_run( ): raise NotImplementedError() + def get_status_code_from_exception( + self, exception: Optional[BaseException] + ) -> Optional[int]: + raise NotImplementedError() + def create_sync_or_upsert_async_pipeline_run( self, owner: str, diff --git a/tfx/orchestration/experimental/core/pipeline_ops.py b/tfx/orchestration/experimental/core/pipeline_ops.py index 76188665f9..452a93ed62 100644 --- a/tfx/orchestration/experimental/core/pipeline_ops.py +++ b/tfx/orchestration/experimental/core/pipeline_ops.py @@ -1298,9 +1298,22 @@ def orchestrate( if filter_fn is None: filter_fn = lambda _: True - all_pipeline_states = pstate.PipelineState.load_all_active_and_owned( - mlmd_connection_manager.primary_mlmd_handle - ) + # Try to load active pipelines. If there is a recoverable error, return True + # and then retry in the next orchestration iteration. + try: + all_pipeline_states = pstate.PipelineState.load_all_active_and_owned( + mlmd_connection_manager.primary_mlmd_handle + ) + except Exception as e: # pylint: disable=broad-except + code = env.get_env().get_status_code_from_exception(e) + if code in status_lib.BATCH_RETRIABLE_ERROR_CODES: + logging.exception( + 'Failed to load active pipeline states. Will retry in next' + ' orchestration iteration.', + ) + return True + raise e + pipeline_states = [s for s in all_pipeline_states if filter_fn(s)] if not pipeline_states: logging.info('No active pipelines to run.') diff --git a/tfx/orchestration/experimental/core/pipeline_ops_test.py b/tfx/orchestration/experimental/core/pipeline_ops_test.py index 56bb115187..376a99d219 100644 --- a/tfx/orchestration/experimental/core/pipeline_ops_test.py +++ b/tfx/orchestration/experimental/core/pipeline_ops_test.py @@ -54,6 +54,7 @@ from tfx.types import standard_artifacts from tfx.utils import status as status_lib +from ml_metadata import errors as mlmd_errors from ml_metadata.proto import metadata_store_pb2 @@ -3589,6 +3590,83 @@ def test_orchestrate_pipelines_with_specified_pipeline_uid( ) self.assertTrue(task_queue.is_empty()) + @parameterized.parameters( + (mlmd_errors.DeadlineExceededError('DeadlineExceededError'), 4), + (mlmd_errors.InternalError('InternalError'), 13), + (mlmd_errors.UnavailableError('UnavailableError'), 14), + (mlmd_errors.ResourceExhaustedError('ResourceExhaustedError'), 8), + ( + status_lib.StatusNotOkError( + code=status_lib.Code.DEADLINE_EXCEEDED, + message='DeadlineExceededError', + ), + 4, + ), + ( + status_lib.StatusNotOkError( + code=status_lib.Code.INTERNAL, message='InternalError' + ), + 13, + ), + ( + status_lib.StatusNotOkError( + code=status_lib.Code.UNAVAILABLE, message='UnavailableError' + ), + 14, + ), + ( + status_lib.StatusNotOkError( + code=status_lib.Code.RESOURCE_EXHAUSTED, + message='ResourceExhaustedError', + ), + 8, + ), + ) + @mock.patch.object(pstate.PipelineState, 'load_all_active_and_owned') + def test_orchestrate_pipelines_with_recoverable_error_from_MLMD( + self, + error, + error_code, + mock_load_all_active_and_owned, + ): + mock_load_all_active_and_owned.side_effect = error + + with test_utils.get_status_code_from_exception_environment(error_code): + with self._mlmd_cm as mlmd_connection_manager: + task_queue = tq.TaskQueue() + orchestrate_result = pipeline_ops.orchestrate( + mlmd_connection_manager, + task_queue, + service_jobs.DummyServiceJobManager(), + ) + self.assertEqual(orchestrate_result, True) + + @parameterized.parameters( + mlmd_errors.InvalidArgumentError('InvalidArgumentError'), + mlmd_errors.FailedPreconditionError('FailedPreconditionError'), + status_lib.StatusNotOkError( + code=status_lib.Code.INVALID_ARGUMENT, message='InvalidArgumentError' + ), + status_lib.StatusNotOkError( + code=status_lib.Code.UNKNOWN, + message='UNKNOWN', + ), + ) + @mock.patch.object(pstate.PipelineState, 'load_all_active_and_owned') + def test_orchestrate_pipelines_with_not_recoverable_error_from_MLMD( + self, error, mock_load_all_active_and_owned + ): + mock_load_all_active_and_owned.side_effect = error + + with self._mlmd_cm as mlmd_connection_manager: + task_queue = tq.TaskQueue() + with self.assertRaises(Exception): + pipeline_ops.orchestrate( + mlmd_connection_manager, + task_queue, + service_jobs.DummyServiceJobManager(), + ) + if __name__ == '__main__': tf.test.main() diff --git a/tfx/orchestration/experimental/core/test_utils.py b/tfx/orchestration/experimental/core/test_utils.py index 563a0fa1e2..e5d0377460 100644 --- a/tfx/orchestration/experimental/core/test_utils.py +++ b/tfx/orchestration/experimental/core/test_utils.py @@ -511,3 +511,15 @@ def prepare_orchestrator_for_pipeline_run( pipeline.sdk_version = 'postprocessed' return _TestEnv() + + +def get_status_code_from_exception_environment(error_code: int): + + class _TestEnv(env._DefaultEnv): # pylint: disable=protected-access + + def get_status_code_from_exception( + self, exception: Optional[BaseException] + ) -> Optional[int]: + return error_code + + return _TestEnv() diff --git a/tfx/utils/status.py b/tfx/utils/status.py index b7da889439..1a546c73d5 100644 --- a/tfx/utils/status.py +++ b/tfx/utils/status.py @@ -49,6 +49,26 @@ class Code(enum.IntEnum): UNAUTHENTICATED = 16 +# These are the error codes that are retriable for USER_FACING traffic. +# See go/stubs-retries. +USER_FACING_RETRIABLE_STATUS_CODES = frozenset( + c.value + for c in [ + Code.UNAVAILABLE, + ] +) + +BATCH_RETRIABLE_ERROR_CODES = frozenset( + c.value + for c in [ + Code.DEADLINE_EXCEEDED, + Code.INTERNAL, + Code.UNAVAILABLE, + Code.RESOURCE_EXHAUSTED, + ] +) + + @attr.s(auto_attribs=True, frozen=True) class Status: """Class to record status of operations.