Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Orchestrator shouldn't crash when MLMD call fails #6828

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions tfx/components/statistics_gen/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from absl import logging
import tensorflow_data_validation as tfdv
from tensorflow_data_validation.statistics import stats_options as options
from tensorflow_data_validation.utils import dashboard_util
from tfx import types
from tfx.components.statistics_gen import stats_artifact_utils
from tfx.components.util import examples_utils
Expand All @@ -27,7 +28,6 @@
from tfx.types import standard_component_specs
from tfx.utils import io_utils
from tfx.utils import json_utils
from tfx.utils import stats_utils


# Default file name for stats generated.
Expand Down Expand Up @@ -151,8 +151,7 @@ def Do(

try:
statistics_artifact.set_string_custom_property(
STATS_DASHBOARD_LINK,
stats_utils.generate_stats_dashboard_link(statistics_artifact),
STATS_DASHBOARD_LINK, dashboard_util.generate_stats_dashboard_link()
)
except Exception as e: # pylint: disable=broad-except
# log on failures to not bring down Statsgen jobs
Expand Down
19 changes: 19 additions & 0 deletions tfx/orchestration/experimental/core/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,20 @@ def should_orchestrate(self, pipeline: pipeline_pb2.Pipeline) -> bool:
Whether the env should orchestrate the pipeline.
"""

@abc.abstractmethod
def get_status_code_from_exception(
    self, exception: Optional[BaseException]
) -> Optional[int]:
  """Maps an exception to its status code.

  Args:
    exception: The exception to inspect. May be None.

  Returns:
    The status code associated with `exception`, or None when the exception
    is not of a known type.
  """


class _DefaultEnv(Env):
"""Default environment."""
Expand Down Expand Up @@ -211,6 +225,11 @@ def should_orchestrate(self, pipeline: pipeline_pb2.Pipeline) -> bool:
# By default, all pipeline runs should be orchestrated.
return True

def get_status_code_from_exception(
    self, exception: Optional[BaseException]
) -> Optional[int]:
  """Returns None for every exception: no exception type is mapped here."""
  del exception  # Unused.
  return None


_ENV = _DefaultEnv()

Expand Down
5 changes: 5 additions & 0 deletions tfx/orchestration/experimental/core/env_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@ def prepare_orchestrator_for_pipeline_run(
):
raise NotImplementedError()

def get_status_code_from_exception(
    self, exception: Optional[BaseException]
) -> Optional[int]:
  """Always raises NotImplementedError, whatever `exception` is."""
  raise NotImplementedError()

def create_sync_or_upsert_async_pipeline_run(
self,
owner: str,
Expand Down
19 changes: 16 additions & 3 deletions tfx/orchestration/experimental/core/pipeline_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1298,9 +1298,22 @@ def orchestrate(
if filter_fn is None:
filter_fn = lambda _: True

all_pipeline_states = pstate.PipelineState.load_all_active_and_owned(
mlmd_connection_manager.primary_mlmd_handle
)
# Try to load active pipelines. If there is a recoverable error, return True
# and then retry in the next orchestration iteration.
try:
all_pipeline_states = pstate.PipelineState.load_all_active_and_owned(
mlmd_connection_manager.primary_mlmd_handle
)
except Exception as e: # pylint: disable=broad-except
code = env.get_env().get_status_code_from_exception(e)
if code in status_lib.BATCH_RETRIABLE_ERROR_CODES:
logging.exception(
'Failed to load active pipeline states. Will retry in next'
' orchestration iteration.',
)
return True
raise e

pipeline_states = [s for s in all_pipeline_states if filter_fn(s)]
if not pipeline_states:
logging.info('No active pipelines to run.')
Expand Down
78 changes: 78 additions & 0 deletions tfx/orchestration/experimental/core/pipeline_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
from tfx.types import standard_artifacts
from tfx.utils import status as status_lib

from ml_metadata import errors as mlmd_errors
from ml_metadata.proto import metadata_store_pb2


Expand Down Expand Up @@ -3589,6 +3590,83 @@ def test_orchestrate_pipelines_with_specified_pipeline_uid(
)
self.assertTrue(task_queue.is_empty())

@parameterized.parameters(
    (mlmd_errors.DeadlineExceededError('DeadlineExceededError'), 4),
    (mlmd_errors.InternalError('InternalError'), 13),
    (mlmd_errors.UnavailableError('UnavailableError'), 14),
    (mlmd_errors.ResourceExhaustedError('ResourceExhaustedError'), 8),
    (
        status_lib.StatusNotOkError(
            code=status_lib.Code.DEADLINE_EXCEEDED,
            message='DeadlineExceededError',
        ),
        4,
    ),
    (
        status_lib.StatusNotOkError(
            code=status_lib.Code.INTERNAL, message='InternalError'
        ),
        13,
    ),
    (
        status_lib.StatusNotOkError(
            code=status_lib.Code.UNAVAILABLE, message='UnavailableError'
        ),
        14,
    ),
    (
        status_lib.StatusNotOkError(
            code=status_lib.Code.RESOURCE_EXHAUSTED,
            message='ResourceExhaustedError',
        ),
        8,
    ),
)
@mock.patch.object(pstate.PipelineState, 'load_all_active_and_owned')
def test_orchestrate_pipelines_with_recoverable_error_from_MLMD(
    self,
    error,
    error_code,
    mock_load_all_active_and_owned,
):
  """Checks orchestrate() returns True when state loading fails retriably.

  `load_all_active_and_owned` is patched to raise `error`, and the active env
  is replaced by one that reports `error_code` for any exception.

  Args:
    error: Exception raised by the patched `load_all_active_and_owned`.
    error_code: Status code the test env reports for the exception.
    mock_load_all_active_and_owned: Mock injected by `mock.patch.object`.
  """
  mock_load_all_active_and_owned.side_effect = error

  status_code_env = test_utils.get_status_code_from_exception_environment(
      error_code
  )
  with status_code_env, self._mlmd_cm as mlmd_connection_manager:
    queue = tq.TaskQueue()
    result = pipeline_ops.orchestrate(
        mlmd_connection_manager,
        queue,
        service_jobs.DummyServiceJobManager(),
    )
    self.assertEqual(result, True)

@parameterized.parameters(
    mlmd_errors.InvalidArgumentError('InvalidArgumentError'),
    mlmd_errors.FailedPreconditionError('FailedPreconditionError'),
    status_lib.StatusNotOkError(
        code=status_lib.Code.INVALID_ARGUMENT, message='InvalidArgumentError'
    ),
    status_lib.StatusNotOkError(
        code=status_lib.Code.UNKNOWN,
        message='UNKNOWN',
    ),
)
@mock.patch.object(pstate.PipelineState, 'load_all_active_and_owned')
def test_orchestrate_pipelines_with_not_recoverable_error_from_MLMD(
    self, error, mock_load_all_active_and_owned
):
  """Checks orchestrate() re-raises a non-retriable state-loading failure.

  `load_all_active_and_owned` is patched to raise `error`; the call to
  `pipeline_ops.orchestrate` is expected to raise an exception of that type.

  Args:
    error: Exception raised by the patched `load_all_active_and_owned`.
    mock_load_all_active_and_owned: Mock injected by `mock.patch.object`.
  """
  mock_load_all_active_and_owned.side_effect = error

  with self._mlmd_cm as mlmd_connection_manager:
    task_queue = tq.TaskQueue()
    # Pin the expected type instead of bare `Exception`, so an unrelated
    # failure (e.g. a NameError or TypeError in the harness) cannot make
    # this test pass.
    with self.assertRaises(type(error)):
      pipeline_ops.orchestrate(
          mlmd_connection_manager,
          task_queue,
          service_jobs.DummyServiceJobManager(),
      )


if __name__ == '__main__':
tf.test.main()
12 changes: 12 additions & 0 deletions tfx/orchestration/experimental/core/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,3 +511,15 @@ def prepare_orchestrator_for_pipeline_run(
pipeline.sdk_version = 'postprocessed'

return _TestEnv()


def get_status_code_from_exception_environment(error_code: int):
  """Builds an env that reports a fixed status code for every exception.

  Args:
    error_code: Value returned by the env's `get_status_code_from_exception`.

  Returns:
    An instance of a `_DefaultEnv` subclass whose
    `get_status_code_from_exception` returns `error_code` regardless of the
    exception passed in.
  """

  class _TestEnv(env._DefaultEnv):  # pylint: disable=protected-access

    def get_status_code_from_exception(
        self, exception: Optional[BaseException]
    ) -> Optional[int]:
      del exception  # Unused: the configured code is returned unconditionally.
      return error_code

  test_env = _TestEnv()
  return test_env
21 changes: 0 additions & 21 deletions tfx/utils/stats_utils.py

This file was deleted.

20 changes: 20 additions & 0 deletions tfx/utils/status.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,26 @@ class Code(enum.IntEnum):
UNAUTHENTICATED = 16


# These are the error codes that are retriable for USER_FACING traffic.
# See go/stubs-retries.
USER_FACING_RETRIABLE_STATUS_CODES = frozenset({Code.UNAVAILABLE.value})

# Error codes in the BATCH retriable set; presumably the batch-traffic
# counterpart of the USER_FACING set above (confirm against go/stubs-retries).
BATCH_RETRIABLE_ERROR_CODES = frozenset({
    Code.DEADLINE_EXCEEDED.value,
    Code.INTERNAL.value,
    Code.UNAVAILABLE.value,
    Code.RESOURCE_EXHAUSTED.value,
})


@attr.s(auto_attribs=True, frozen=True)
class Status:
"""Class to record status of operations.
Expand Down