From 7d0e9ba86e60a3880f54e64999b3d570940c12e0 Mon Sep 17 00:00:00 2001
From: karajan1001
Date: Fri, 5 Aug 2022 21:13:09 +0800
Subject: [PATCH] queue status: ERROR: Invalid experiment

fix: #8014

> ERROR: Invalid experiment '{entry.stash_rev[:7]}'.

This happens when a queue task fails with an SCM error such as:
{"exc_type": "GitMergeError", "exc_message": ["Cannot fast-forward HEAD
to '05100047a341f2fa4a02421289d48f84d8c45e86'"], "exc_module": "dvc.scm"}
The failed task creates neither an infofile nor a fail_stash entry.

1. Don't raise DvcException when no infofile is found for a failed task
   (successful tasks still raise it).
2. Add a new unit test for this.
---
 dvc/repo/experiments/queue/celery.py          | 45 +++++++++----------
 .../unit/repo/experiments/queue/test_local.py | 27 ++++++++++-
 2 files changed, 47 insertions(+), 25 deletions(-)

diff --git a/dvc/repo/experiments/queue/celery.py b/dvc/repo/experiments/queue/celery.py
index 50bf85fd7f..05eee9b286 100644
--- a/dvc/repo/experiments/queue/celery.py
+++ b/dvc/repo/experiments/queue/celery.py
@@ -14,6 +14,7 @@
     Set,
 )
 
+from celery.result import AsyncResult
 from funcy import cached_property
 from kombu.message import Message
 
@@ -23,7 +24,6 @@
 
 from ..exceptions import UnresolvedQueueExpNamesError
 from ..executor.base import EXEC_TMP_DIR, ExecutorInfo, ExecutorResult
-from ..stash import ExpStashEntry
 from .base import BaseStashQueue, QueueDoneResult, QueueEntry, QueueGetResult
 from .tasks import run_exp
 
@@ -41,7 +41,7 @@ class _MessageEntry(NamedTuple):
 
 
 class _TaskEntry(NamedTuple):
-    task_id: str
+    async_result: AsyncResult
     entry: QueueEntry
 
 
@@ -192,30 +192,37 @@ def _iter_processed(self) -> Generator[_MessageEntry, None, None]:
             yield _MessageEntry(msg, QueueEntry.from_dict(entry_dict))
 
     def _iter_active_tasks(self) -> Generator[_TaskEntry, None, None]:
-        from celery.result import AsyncResult
 
         for msg, entry in self._iter_processed():
             task_id = msg.headers["id"]
             result: AsyncResult = AsyncResult(task_id)
             if not result.ready():
-                yield _TaskEntry(task_id, entry)
+                yield _TaskEntry(result, entry)
 
     def _iter_done_tasks(self) -> Generator[_TaskEntry, None, None]:
-        from celery.result import AsyncResult
 
         for msg, entry in self._iter_processed():
            task_id = msg.headers["id"]
             result: AsyncResult = AsyncResult(task_id)
             if result.ready():
-                yield _TaskEntry(task_id, entry)
+                yield _TaskEntry(result, entry)
 
     def iter_active(self) -> Generator[QueueEntry, None, None]:
         for _, entry in self._iter_active_tasks():
             yield entry
 
     def iter_done(self) -> Generator[QueueDoneResult, None, None]:
-        for _, entry in self._iter_done_tasks():
-            yield QueueDoneResult(entry, self.get_result(entry))
+        for result, entry in self._iter_done_tasks():
+            try:
+                exp_result = self.get_result(entry)
+            except FileNotFoundError:
+                if result.status == "SUCCESS":
+                    raise DvcException(
+                        f"Invalid experiment '{entry.stash_rev[:7]}'."
+                    )
+                elif result.status == "FAILURE":
+                    exp_result = None
+            yield QueueDoneResult(entry, exp_result)
 
     def iter_success(self) -> Generator[QueueDoneResult, None, None]:
         for queue_entry, exp_result in self.iter_done():
@@ -223,14 +230,8 @@ def iter_success(self) -> Generator[QueueDoneResult, None, None]:
                 yield QueueDoneResult(queue_entry, exp_result)
 
     def iter_failed(self) -> Generator[QueueDoneResult, None, None]:
-        failed_revs: Dict[str, ExpStashEntry] = (
-            dict(self.failed_stash.stash_revs)
-            if self.failed_stash is not None
-            else {}
-        )
-
         for queue_entry, exp_result in self.iter_done():
-            if exp_result is None and queue_entry.stash_rev in failed_revs:
+            if exp_result is None:
                 yield QueueDoneResult(queue_entry, exp_result)
 
     def reproduce(self) -> Mapping[str, Mapping[str, str]]:
@@ -240,7 +241,6 @@ def get_result(
         self, entry: QueueEntry, timeout: Optional[float] = None
     ) -> Optional[ExecutorResult]:
         from celery.exceptions import TimeoutError as _CeleryTimeout
-        from celery.result import AsyncResult
 
         def _load_info(rev: str) -> ExecutorInfo:
             infofile = self.get_infofile_path(rev)
@@ -261,11 +261,12 @@ def _load_collected(rev: str) -> Optional[ExecutorResult]:
         for queue_entry in self.iter_queued():
             if entry.stash_rev == queue_entry.stash_rev:
                 raise DvcException("Experiment has not been started.")
-        for task_id, active_entry in self._iter_active_tasks():
+        for result, active_entry in self._iter_active_tasks():
             if entry.stash_rev == active_entry.stash_rev:
-                logger.debug("Waiting for exp task '%s' to complete", task_id)
+                logger.debug(
+                    "Waiting for exp task '%s' to complete", result.id
+                )
                 try:
-                    result: AsyncResult = AsyncResult(task_id)
                     result.get(timeout=timeout)
                 except _CeleryTimeout as exc:
                     raise DvcException(
@@ -277,11 +278,7 @@ def _load_collected(rev: str) -> Optional[ExecutorResult]:
         # NOTE: It's possible for an exp to complete while iterating through
         # other queued and active tasks, in which case the exp will get moved
         # out of the active task list, and needs to be loaded here.
-        try:
-            return _load_collected(entry.stash_rev)
-        except FileNotFoundError:
-            pass
-        raise DvcException(f"Invalid experiment '{entry.stash_rev[:7]}'.")
+        return _load_collected(entry.stash_rev)
 
     def kill(self, revs: Collection[str]) -> None:
         to_kill: Set[QueueEntry] = set()
diff --git a/tests/unit/repo/experiments/queue/test_local.py b/tests/unit/repo/experiments/queue/test_local.py
index 8d38af002f..3ebaf3aa9b 100644
--- a/tests/unit/repo/experiments/queue/test_local.py
+++ b/tests/unit/repo/experiments/queue/test_local.py
@@ -4,8 +4,10 @@
 from celery import shared_task
 from flaky.flaky_decorator import flaky
 
+from dvc.exceptions import DvcException
 from dvc.repo.experiments.exceptions import UnresolvedExpNamesError
 from dvc.repo.experiments.executor.local import TempDirExecutor
+from dvc.repo.experiments.queue.base import QueueDoneResult
 from dvc.repo.experiments.refs import EXEC_BASELINE, EXEC_HEAD, EXEC_MERGE
 
 
@@ -37,7 +39,7 @@ def test_shutdown_with_kill(test_queue, mocker):
     mocker.patch.object(
         test_queue,
         "_iter_active_tasks",
-        return_value=[(result.id, mock_entry)],
+        return_value=[(result, mock_entry)],
     )
     kill_spy = mocker.patch.object(test_queue.proc, "kill")
 
@@ -120,3 +122,26 @@ def test_queue_clean_workspace_refs(git_dir, tmp_dir):
 
     for ref in exec_heads:
         assert git_dir.scm.get_ref(ref) is None
+
+
+@pytest.mark.parametrize("status", ["FAILURE", "SUCCESS"])
+def test_queue_iter_done_task(test_queue, mocker, status):
+
+    mock_entry = mocker.Mock(stash_rev=_foo.name)
+
+    result = mocker.Mock(status=status)
+
+    mocker.patch.object(
+        test_queue,
+        "_iter_done_tasks",
+        return_value=[(result, mock_entry)],
+    )
+
+    if status == "FAILURE":
+        assert list(test_queue.iter_failed()) == [
+            QueueDoneResult(mock_entry, None)
+        ]
+
+    elif status == "SUCCESS":
+        with pytest.raises(DvcException):
+            assert list(test_queue.iter_success())
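
A minimal, standalone sketch of the status handling this patch adds to iter_done(). The names DoneTask and collect_done and the use of RuntimeError are hypothetical stand-ins for illustration only; the real code path goes through Celery's AsyncResult.status, ExecutorResult, and DvcException.

from typing import NamedTuple, Optional


class DoneTask(NamedTuple):
    """Hypothetical stand-in for a finished queue task (not DVC's _TaskEntry)."""

    status: str          # Celery task state: "SUCCESS" or "FAILURE"
    stash_rev: str       # experiment stash revision
    has_infofile: bool   # whether the executor managed to write its infofile


def collect_done(task: DoneTask) -> Optional[dict]:
    """Mirror the branching added to iter_done(): a missing infofile is only
    an error when Celery reports the task as successful."""
    if task.has_infofile:
        return {"rev": task.stash_rev}  # placeholder for a real ExecutorResult
    if task.status == "SUCCESS":
        # A successful task without an infofile is still an invalid experiment.
        raise RuntimeError(f"Invalid experiment '{task.stash_rev[:7]}'.")
    # A task that failed before the executor ran (e.g. GitMergeError) is
    # reported as failed with no result instead of aborting the listing.
    return None


# A pre-executor failure no longer raises; it simply shows up as failed.
assert collect_done(DoneTask("FAILURE", "05100047a341f2fa4a02", False)) is None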