Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 21 additions & 24 deletions dvc/repo/experiments/queue/celery.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
Set,
)

from celery.result import AsyncResult
from funcy import cached_property
from kombu.message import Message

Expand All @@ -23,7 +24,6 @@

from ..exceptions import UnresolvedQueueExpNamesError
from ..executor.base import EXEC_TMP_DIR, ExecutorInfo, ExecutorResult
from ..stash import ExpStashEntry
from .base import BaseStashQueue, QueueDoneResult, QueueEntry, QueueGetResult
from .tasks import run_exp

Expand All @@ -41,7 +41,7 @@ class _MessageEntry(NamedTuple):


class _TaskEntry(NamedTuple):
task_id: str
async_result: AsyncResult
entry: QueueEntry


Expand Down Expand Up @@ -192,45 +192,46 @@ def _iter_processed(self) -> Generator[_MessageEntry, None, None]:
yield _MessageEntry(msg, QueueEntry.from_dict(entry_dict))

def _iter_active_tasks(self) -> Generator[_TaskEntry, None, None]:
from celery.result import AsyncResult

for msg, entry in self._iter_processed():
task_id = msg.headers["id"]
result: AsyncResult = AsyncResult(task_id)
if not result.ready():
yield _TaskEntry(task_id, entry)
yield _TaskEntry(result, entry)

def _iter_done_tasks(self) -> Generator[_TaskEntry, None, None]:
from celery.result import AsyncResult

for msg, entry in self._iter_processed():
task_id = msg.headers["id"]
result: AsyncResult = AsyncResult(task_id)
if result.ready():
yield _TaskEntry(task_id, entry)
yield _TaskEntry(result, entry)

def iter_active(self) -> Generator[QueueEntry, None, None]:
for _, entry in self._iter_active_tasks():
yield entry

def iter_done(self) -> Generator[QueueDoneResult, None, None]:
for _, entry in self._iter_done_tasks():
yield QueueDoneResult(entry, self.get_result(entry))
for result, entry in self._iter_done_tasks():
try:
exp_result = self.get_result(entry)
except FileNotFoundError:
if result.status == "SUCCESS":
raise DvcException(
f"Invalid experiment '{entry.stash_rev[:7]}'."
)
elif result.status == "FAILURE":
exp_result = None
yield QueueDoneResult(entry, exp_result)

def iter_success(self) -> Generator[QueueDoneResult, None, None]:
for queue_entry, exp_result in self.iter_done():
if exp_result and exp_result.exp_hash and exp_result.ref_info:
yield QueueDoneResult(queue_entry, exp_result)

def iter_failed(self) -> Generator[QueueDoneResult, None, None]:
failed_revs: Dict[str, ExpStashEntry] = (
dict(self.failed_stash.stash_revs)
if self.failed_stash is not None
else {}
)

for queue_entry, exp_result in self.iter_done():
if exp_result is None and queue_entry.stash_rev in failed_revs:
if exp_result is None:
yield QueueDoneResult(queue_entry, exp_result)

def reproduce(self) -> Mapping[str, Mapping[str, str]]:
Expand All @@ -240,7 +241,6 @@ def get_result(
self, entry: QueueEntry, timeout: Optional[float] = None
) -> Optional[ExecutorResult]:
from celery.exceptions import TimeoutError as _CeleryTimeout
from celery.result import AsyncResult

def _load_info(rev: str) -> ExecutorInfo:
infofile = self.get_infofile_path(rev)
Expand All @@ -261,11 +261,12 @@ def _load_collected(rev: str) -> Optional[ExecutorResult]:
for queue_entry in self.iter_queued():
if entry.stash_rev == queue_entry.stash_rev:
raise DvcException("Experiment has not been started.")
for task_id, active_entry in self._iter_active_tasks():
for result, active_entry in self._iter_active_tasks():
if entry.stash_rev == active_entry.stash_rev:
logger.debug("Waiting for exp task '%s' to complete", task_id)
logger.debug(
"Waiting for exp task '%s' to complete", result.id
)
try:
result: AsyncResult = AsyncResult(task_id)
result.get(timeout=timeout)
except _CeleryTimeout as exc:
raise DvcException(
Expand All @@ -277,11 +278,7 @@ def _load_collected(rev: str) -> Optional[ExecutorResult]:
# NOTE: It's possible for an exp to complete while iterating through
# other queued and active tasks, in which case the exp will get moved
# out of the active task list, and needs to be loaded here.
try:
return _load_collected(entry.stash_rev)
except FileNotFoundError:
pass
raise DvcException(f"Invalid experiment '{entry.stash_rev[:7]}'.")
return _load_collected(entry.stash_rev)

def kill(self, revs: Collection[str]) -> None:
to_kill: Set[QueueEntry] = set()
Expand Down
27 changes: 26 additions & 1 deletion tests/unit/repo/experiments/queue/test_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
from celery import shared_task
from flaky.flaky_decorator import flaky

from dvc.exceptions import DvcException
from dvc.repo.experiments.exceptions import UnresolvedExpNamesError
from dvc.repo.experiments.executor.local import TempDirExecutor
from dvc.repo.experiments.queue.base import QueueDoneResult
from dvc.repo.experiments.refs import EXEC_BASELINE, EXEC_HEAD, EXEC_MERGE


Expand Down Expand Up @@ -37,7 +39,7 @@ def test_shutdown_with_kill(test_queue, mocker):
mocker.patch.object(
test_queue,
"_iter_active_tasks",
return_value=[(result.id, mock_entry)],
return_value=[(result, mock_entry)],
)
kill_spy = mocker.patch.object(test_queue.proc, "kill")

Expand Down Expand Up @@ -120,3 +122,26 @@ def test_queue_clean_workspace_refs(git_dir, tmp_dir):

for ref in exec_heads:
assert git_dir.scm.get_ref(ref) is None


@pytest.mark.parametrize("status", ["FAILURE", "SUCCESS"])
def test_queue_iter_done_task(test_queue, mocker, status):

mock_entry = mocker.Mock(stash_rev=_foo.name)

result = mocker.Mock(status=status)

mocker.patch.object(
test_queue,
"_iter_done_tasks",
return_value=[(result, mock_entry)],
)

if status == "FAILURE":
assert list(test_queue.iter_failed()) == [
QueueDoneResult(mock_entry, None)
]

elif status == "SUCCESS":
with pytest.raises(DvcException):
assert list(test_queue.iter_success())