Skip to content

Commit

Permalink
No-op
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 621362399
  • Loading branch information
tfx-copybara committed Apr 4, 2024
1 parent 135a198 commit 201a6cc
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 0 deletions.
13 changes: 13 additions & 0 deletions tfx/orchestration/experimental/core/pipeline_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,7 @@ def initiate_pipeline_start(
def stop_pipelines(
mlmd_handle: metadata.Metadata,
pipeline_uids: List[task_lib.PipelineUid],
skip_wait_for_inactivation: bool = False,
timeout_secs: Optional[float] = None,
ignore_non_existent_or_inactive: Optional[bool] = False,
) -> None:
Expand All @@ -270,6 +271,8 @@ def stop_pipelines(
Args:
mlmd_handle: A handle to the MLMD db.
pipeline_uids: UIDs of the pipeline to be stopped.
skip_wait_for_inactivation: Default False. If true, then prematurely returns
and skips waiting for all pipelines to be inactive.
timeout_secs: Amount of time in seconds total to wait for all pipelines to
stop. If `None`, waits indefinitely.
ignore_non_existent_or_inactive: If a pipeline is not found or inactive,
Expand Down Expand Up @@ -309,6 +312,14 @@ def stop_pipelines(
)
continue
raise e

if skip_wait_for_inactivation:
logging.info(
'Skipping wait for all pipelines to be inactive; pipeline ids: %s.',
pipeline_ids_str,
)
return

logging.info(
'Waiting for pipelines to be stopped; pipeline ids: %s', pipeline_ids_str
)
Expand Down Expand Up @@ -336,13 +347,15 @@ def _are_pipelines_inactivated() -> bool:
def stop_pipeline(
mlmd_handle: metadata.Metadata,
pipeline_uid: task_lib.PipelineUid,
skip_wait_for_inactivation: bool = False,
timeout_secs: Optional[float] = None,
) -> None:
"""Stops a single pipeline. Convenience wrapper around stop_pipelines."""
return stop_pipelines(
mlmd_handle=mlmd_handle,
pipeline_uids=[pipeline_uid],
timeout_secs=timeout_secs,
skip_wait_for_inactivation=skip_wait_for_inactivation,
)


Expand Down
35 changes: 35 additions & 0 deletions tfx/orchestration/experimental/core/pipeline_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1229,6 +1229,41 @@ def _inactivate(pipeline_state):

thread.join()

@parameterized.named_parameters(
dict(testcase_name='async', pipeline=_test_pipeline('pipeline1')),
dict(
testcase_name='sync',
pipeline=_test_pipeline('pipeline1', pipeline_pb2.Pipeline.SYNC),
),
)
def test_stop_pipeline_skip_wait_for_inactivation(self, pipeline):
with self._mlmd_connection as m:
mock_wait_for_predicate = self.enter_context(
mock.patch.object(pipeline_ops, '_wait_for_predicate', autospec=True)
)
pipeline_state = pipeline_ops.initiate_pipeline_start(m, pipeline)

def _inactivate(pipeline_state):
time.sleep(2.0)
with pipeline_ops._PIPELINE_OPS_LOCK:
with pipeline_state:
pipeline_state.set_pipeline_execution_state(
metadata_store_pb2.Execution.COMPLETE
)

thread = threading.Thread(target=_inactivate, args=(pipeline_state,))
thread.start()

pipeline_ops.stop_pipeline(
m,
task_lib.PipelineUid.from_pipeline(pipeline),
timeout_secs=20.0,
skip_wait_for_inactivation=True,
)
mock_wait_for_predicate.assert_not_called()

thread.join()

@parameterized.named_parameters(
dict(testcase_name='async', pipeline=_test_pipeline('pipeline1')),
dict(
Expand Down

0 comments on commit 201a6cc

Please sign in to comment.