Skip to content

Commit

Permalink
Orchestrator shouldn't crash when MLMD call fails
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 641385227
  • Loading branch information
tfx-copybara committed Jun 28, 2024
1 parent 4e71a35 commit 88d7f18
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 3 deletions.
12 changes: 9 additions & 3 deletions tfx/orchestration/experimental/core/pipeline_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1298,9 +1298,15 @@ def orchestrate(
if filter_fn is None:
filter_fn = lambda _: True

all_pipeline_states = pstate.PipelineState.load_all_active_and_owned(
mlmd_connection_manager.primary_mlmd_handle
)
# Try to load active pipelines. If there is a recoverable error, return False
# and then retry in the next orchestration iteration.
try:
all_pipeline_states = pstate.PipelineState.load_all_active_and_owned(
mlmd_connection_manager.primary_mlmd_handle
)
except Exception as e: # pylint: disable=broad-except
raise e

pipeline_states = [s for s in all_pipeline_states if filter_fn(s)]
if not pipeline_states:
logging.info('No active pipelines to run.')
Expand Down
61 changes: 61 additions & 0 deletions tfx/orchestration/experimental/core/pipeline_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
from tfx.types import standard_artifacts
from tfx.utils import status as status_lib

from ml_metadata import errors as mlmd_errors
from ml_metadata.proto import metadata_store_pb2


Expand Down Expand Up @@ -3589,6 +3590,66 @@ def test_orchestrate_pipelines_with_specified_pipeline_uid(
)
self.assertTrue(task_queue.is_empty())

@parameterized.parameters(
mlmd_errors.DeadlineExceededError('DeadlineExceededError'),
mlmd_errors.InternalError('InternalError'),
mlmd_errors.UnavailableError('UnavailableError'),
mlmd_errors.ResourceExhaustedError('ResourceExhaustedError'),
status_lib.StatusNotOkError(
code=status_lib.Code.DEADLINE_EXCEEDED,
message='DeadlineExceededError',
),
status_lib.StatusNotOkError(
code=status_lib.Code.INTERNAL, message='InternalError'
),
status_lib.StatusNotOkError(
code=status_lib.Code.UNAVAILABLE, message='UnavailableError'
),
status_lib.StatusNotOkError(
code=status_lib.Code.RESOURCE_EXHAUSTED,
message='ResourceExhaustedError',
),
)
@mock.patch.object(pstate.PipelineState, 'load_all_active_and_owned')
def test_orchestrate_pipelines_with_recoverable_error_from_MLMD(
self, error, mock_load_all_active_and_owned
):
mock_load_all_active_and_owned.side_effect = error

with self._mlmd_cm as mlmd_connection_manager:
task_queue = tq.TaskQueue()
orchestrate_result = pipeline_ops.orchestrate(
mlmd_connection_manager,
task_queue,
service_jobs.DummyServiceJobManager(),
)
self.assertEqual(orchestrate_result, False)

@parameterized.parameters(
mlmd_errors.InvalidArgumentError('InvalidArgumentError'),
mlmd_errors.FailedPreconditionError('FailedPreconditionError'),
status_lib.StatusNotOkError(
code=status_lib.Code.INVALID_ARGUMENT, message='InvalidArgumentError'
),
status_lib.StatusNotOkError(
code=status_lib.Code.UNKNOWN,
message='UNKNOWN',
),
)
@mock.patch.object(pstate.PipelineState, 'load_all_active_and_owned')
def test_orchestrate_pipelines_with_not_recoverable_error_from_MLMD(
self, error, mock_load_all_active_and_owned
):
mock_load_all_active_and_owned.side_effect = error

with self._mlmd_cm as mlmd_connection_manager:
task_queue = tq.TaskQueue()
with self.assertRaises(Exception):
pipeline_ops.orchestrate(
mlmd_connection_manager,
task_queue,
service_jobs.DummyServiceJobManager(),
)

if __name__ == '__main__':
tf.test.main()

0 comments on commit 88d7f18

Please sign in to comment.