Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions dvc/commands/queue/kill.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,32 @@ class CmdQueueKill(CmdBase):
"""Kill exp task in queue."""

def run(self):
self.repo.experiments.celery_queue.kill(revs=self.args.task)
self.repo.experiments.celery_queue.kill(
revs=self.args.task, force=self.args.force
)

return 0


def add_parser(queue_subparsers, parent_parser):
QUEUE_KILL_HELP = "Kill actively running experiment queue tasks."
QUEUE_KILL_HELP = (
"Gracefully interrupt running experiment queue tasks "
"(equivalent to Ctrl-C)"
)
queue_kill_parser = queue_subparsers.add_parser(
"kill",
parents=[parent_parser],
description=append_doc_link(QUEUE_KILL_HELP, "queue/kill"),
help=QUEUE_KILL_HELP,
formatter_class=argparse.RawDescriptionHelpFormatter,
)
queue_kill_parser.add_argument(
"-f",
"--force",
action="store_true",
default=False,
help="Forcefully and immediately kill running experiment queue tasks",
)
queue_kill_parser.add_argument(
"task",
nargs="*",
Expand Down
18 changes: 10 additions & 8 deletions dvc/repo/experiments/queue/celery.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,12 +298,15 @@ def _get_running_task_ids(self) -> Set[str]:
return running_task_ids

def _try_to_kill_tasks(
self, to_kill: Dict[QueueEntry, str]
self, to_kill: Dict[QueueEntry, str], force: bool
) -> Dict[QueueEntry, str]:
fail_to_kill_entries: Dict[QueueEntry, str] = {}
for queue_entry, rev in to_kill.items():
try:
self.proc.kill(queue_entry.stash_rev)
if force:
self.proc.kill(queue_entry.stash_rev)
else:
self.proc.interrupt(queue_entry.stash_rev)
logger.debug(f"Task {rev} had been killed.")
except ProcessLookupError:
fail_to_kill_entries[queue_entry] = rev
Expand Down Expand Up @@ -331,20 +334,19 @@ def _mark_inactive_tasks_failure(self, remained_entries):
if remained_revs:
raise CannotKillTasksError(remained_revs)

def _kill_entries(self, entries: Dict[QueueEntry, str]):
def _kill_entries(self, entries: Dict[QueueEntry, str], force: bool):
logger.debug(
"Found active tasks: '%s' to kill",
list(entries.values()),
)
inactive_entries: Dict[QueueEntry, str] = self._try_to_kill_tasks(
entries
entries, force
)

if inactive_entries:
self._mark_inactive_tasks_failure(inactive_entries)

def kill(self, revs: Collection[str]) -> None:

def kill(self, revs: Collection[str], force: bool = False) -> None:
name_dict: Dict[
str, Optional[QueueEntry]
] = self.match_queue_entry_by_name(set(revs), self.iter_active())
Expand All @@ -360,7 +362,7 @@ def kill(self, revs: Collection[str]) -> None:
raise UnresolvedQueueExpNamesError(missing_revs)

if to_kill:
self._kill_entries(to_kill)
self._kill_entries(to_kill, force)

def shutdown(self, kill: bool = False):
self.celery.control.shutdown()
Expand All @@ -369,7 +371,7 @@ def shutdown(self, kill: bool = False):
for entry in self.iter_active():
to_kill[entry] = entry.name or entry.stash_rev
if to_kill:
self._kill_entries(to_kill)
self._kill_entries(to_kill, True)

def follow(
self,
Expand Down
2 changes: 1 addition & 1 deletion dvc/stage/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,9 @@ def _run(stage: "Stage", executable, cmd, checkpoint_func, **kwargs):
threading.current_thread(),
threading._MainThread, # type: ignore[attr-defined]
)
old_handler = None

exec_cmd = _make_cmd(executable, cmd)
old_handler = None
Comment on lines +87 to -89
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@karajan1001 What is this change about?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@karajan1001 If you can explain this change, I'm happy to approve. Thanks!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, nothing is affected by this change. I will undo it.


try:
p = subprocess.Popen(exec_cmd, **kwargs)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ dependencies = [
"typing-extensions>=3.7.4",
"scmrepo==0.1.5",
"dvc-render==0.0.17",
"dvc-task==0.1.8",
"dvc-task==0.1.9",
"dvclive>=1.2.2",
"dvc-data==0.28.4",
"dvc-http==2.27.2",
Expand Down
3 changes: 2 additions & 1 deletion tests/unit/command/test_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def test_experiments_kill(dvc, scm, mocker):
[
"queue",
"kill",
"--force",
"exp1",
"exp2",
]
Expand All @@ -105,7 +106,7 @@ def test_experiments_kill(dvc, scm, mocker):
)

assert cmd.run() == 0
m.assert_called_once_with(revs=["exp1", "exp2"])
m.assert_called_once_with(revs=["exp1", "exp2"], force=True)


def test_experiments_start(dvc, scm, mocker):
Expand Down
9 changes: 5 additions & 4 deletions tests/unit/repo/experiments/queue/test_celery.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def test_shutdown_with_kill(test_queue, mocker):

shutdown_spy.assert_called_once()
kill_spy.assert_called_once_with(
{mock_entry_foo: "foo", mock_entry_bar: "bar"}
{mock_entry_foo: "foo", mock_entry_bar: "bar"}, True
)


Expand Down Expand Up @@ -78,7 +78,8 @@ def test_post_run_after_kill(test_queue):
assert result_foo.get(timeout=10) == "foo"


def test_celery_queue_kill(test_queue, mocker):
@pytest.mark.parametrize("force", [True, False])
def test_celery_queue_kill(test_queue, mocker, force):

mock_entry_foo = mocker.Mock(stash_rev="foo")
mock_entry_bar = mocker.Mock(stash_rev="bar")
Expand Down Expand Up @@ -137,13 +138,13 @@ def kill_function(rev):

kill_mock = mocker.patch.object(
test_queue.proc,
"kill",
"kill" if force else "interrupt",
side_effect=mocker.MagicMock(side_effect=kill_function),
)
with pytest.raises(
CannotKillTasksError, match="Task 'foobar' is initializing,"
):
test_queue.kill(["bar", "foo", "foobar"])
test_queue.kill(["bar", "foo", "foobar"], force=force)
assert kill_mock.called_once_with(mock_entry_foo.stash_rev)
assert kill_mock.called_once_with(mock_entry_bar.stash_rev)
assert kill_mock.called_once_with(mock_entry_foobar.stash_rev)
Expand Down