diff --git a/dvc/repo/experiments/remove.py b/dvc/repo/experiments/remove.py index a919608bb6..8755ac3004 100644 --- a/dvc/repo/experiments/remove.py +++ b/dvc/repo/experiments/remove.py @@ -1,8 +1,10 @@ import logging +from typing import List, Optional from dvc.exceptions import InvalidArgumentError from dvc.repo import locked from dvc.repo.scm_context import scm_context +from dvc.scm.base import RevError from .base import EXPS_NAMESPACE, ExpRefInfo from .utils import exp_refs_by_name, remove_exp_refs @@ -21,32 +23,43 @@ def remove(repo, exp_names=None, queue=False, **kwargs): removed += len(repo.experiments.stash) repo.experiments.stash.clear() if exp_names: - ref_infos = list(_get_exp_refs(repo, exp_names)) - remove_exp_refs(repo.scm, ref_infos) - removed += len(ref_infos) + remained = _remove_commited_exps(repo, exp_names) + remained = _remove_queued_exps(repo, remained) + if remained: + raise InvalidArgumentError( + "'{}' is not a valid experiment".format(";".join(remained)) + ) + removed += len(exp_names) - len(remained) return removed -def _get_exp_refs(repo, exp_names): - cur_rev = repo.scm.get_rev() - for name in exp_names: - if name.startswith(EXPS_NAMESPACE): - if not repo.scm.get_ref(name): - raise InvalidArgumentError( - f"'{name}' is not a valid experiment name" - ) - yield ExpRefInfo.from_ref(name) - else: +def _get_exp_stash_index(repo, ref_or_rev: str) -> Optional[int]: + stash_revs = repo.experiments.stash_revs + for _, ref_info in stash_revs.items(): + if ref_info.name == ref_or_rev: + return ref_info.index + try: + rev = repo.scm.resolve_rev(ref_or_rev) + if rev in stash_revs: + return stash_revs.get(rev).index + except RevError: + pass + return None - exp_refs = list(exp_refs_by_name(repo.scm, name)) - if not exp_refs: - raise InvalidArgumentError( - f"'{name}' is not a valid experiment name" - ) - yield _get_ref(exp_refs, name, cur_rev) + +def _get_exp_ref(repo, exp_name: str) -> Optional[ExpRefInfo]: + cur_rev = repo.scm.get_rev() + if exp_name.startswith(EXPS_NAMESPACE): + if repo.scm.get_ref(exp_name): + return ExpRefInfo.from_ref(exp_name) + else: + exp_refs = list(exp_refs_by_name(repo.scm, exp_name)) + if exp_refs: + return _get_ref(exp_refs, exp_name, cur_rev) + return None -def _get_ref(ref_infos, name, cur_rev): +def _get_ref(ref_infos, name, cur_rev) -> Optional[ExpRefInfo]: if len(ref_infos) > 1: for info in ref_infos: if info.baseline_sha == cur_rev: @@ -61,3 +74,28 @@ def _get_ref(ref_infos, name, cur_rev): msg.extend([f"\t{info}" for info in ref_infos]) raise InvalidArgumentError("\n".join(msg)) return ref_infos[0] + + +def _remove_commited_exps(repo, refs: List[str]) -> List[str]: + remain_list = [] + remove_list = [] + for ref in refs: + ref_info = _get_exp_ref(repo, ref) + if ref_info: + remove_list.append(ref_info) + else: + remain_list.append(ref) + if remove_list: + remove_exp_refs(repo.scm, remove_list) + return remain_list + + +def _remove_queued_exps(repo, refs_or_revs: List[str]) -> List[str]: + remain_list = [] + for ref_or_rev in refs_or_revs: + stash_index = _get_exp_stash_index(repo, ref_or_rev) + if stash_index is None: + remain_list.append(ref_or_rev) + else: + repo.experiments.stash.drop(stash_index) + return remain_list diff --git a/tests/func/experiments/test_experiments.py b/tests/func/experiments/test_experiments.py index 47451f4cb3..5a755f8042 100644 --- a/tests/func/experiments/test_experiments.py +++ b/tests/func/experiments/test_experiments.py @@ -576,21 +576,6 @@ def test_run_metrics(tmp_dir, scm, dvc, exp_stage, mocker): assert show_mock.called_once() -def test_remove(tmp_dir, scm, dvc, exp_stage): - results = dvc.experiments.run(exp_stage.addressing, params=["foo=2"]) - exp = first(results) - ref_info = first(exp_refs_by_rev(scm, exp)) - dvc.experiments.run(exp_stage.addressing, params=["foo=3"], queue=True) - - removed = dvc.experiments.remove([str(ref_info)]) - assert removed == 1 - assert scm.get_ref(str(ref_info)) is None - - removed = dvc.experiments.remove(queue=True) - assert removed == 1 - assert len(dvc.experiments.stash) == 0 - - def test_checkout_targets_deps(tmp_dir, scm, dvc, exp_stage): from dvc.utils.fs import remove diff --git a/tests/func/experiments/test_remove.py b/tests/func/experiments/test_remove.py new file mode 100644 index 0000000000..f12f142532 --- /dev/null +++ b/tests/func/experiments/test_remove.py @@ -0,0 +1,74 @@ +import pytest +from funcy import first + +from dvc.exceptions import InvalidArgumentError +from dvc.repo.experiments.utils import exp_refs_by_rev + + +def test_remove_experiments_by_ref(tmp_dir, scm, dvc, exp_stage, caplog): + queue_length = 3 + ref_list = [] + + for i in range(queue_length): + results = dvc.experiments.run( + exp_stage.addressing, params=[f"foo={i}"] + ) + ref_info = first(exp_refs_by_rev(scm, first(results))) + ref_list.append(str(ref_info)) + + with pytest.raises(InvalidArgumentError): + assert dvc.experiments.remove(ref_list[:2] + ["non-exist"]) + assert scm.get_ref(str(ref_list[0])) is None + assert scm.get_ref(str(ref_list[1])) is None + assert scm.get_ref(str(ref_list[2])) is not None + + +def test_remove_all_queued_experiments(tmp_dir, scm, dvc, exp_stage): + queue_length = 3 + + for i in range(queue_length): + dvc.experiments.run( + exp_stage.addressing, params=[f"foo={i}"], queue=True + ) + + results = dvc.experiments.run( + exp_stage.addressing, params=[f"foo={queue_length}"] + ) + ref_info = first(exp_refs_by_rev(scm, first(results))) + + assert len(dvc.experiments.stash) == queue_length + assert dvc.experiments.remove(queue=True) == queue_length + assert len(dvc.experiments.stash) == 0 + assert scm.get_ref(str(ref_info)) is not None + + +def test_remove_special_queued_experiments(tmp_dir, scm, dvc, exp_stage): + results = dvc.experiments.run( + exp_stage.addressing, params=["foo=1"], queue=True, name="queue1" + ) + rev1 = first(results) + results = dvc.experiments.run( + exp_stage.addressing, params=["foo=2"], queue=True, name="queue2" + ) + rev2 = first(results) + results = dvc.experiments.run( + exp_stage.addressing, params=["foo=3"], queue=True, name="queue3" + ) + rev3 = first(results) + results = dvc.experiments.run(exp_stage.addressing, params=["foo=4"]) + ref_info1 = first(exp_refs_by_rev(scm, first(results))) + results = dvc.experiments.run(exp_stage.addressing, params=["foo=5"]) + ref_info2 = first(exp_refs_by_rev(scm, first(results))) + + assert rev1 in dvc.experiments.stash_revs + assert rev2 in dvc.experiments.stash_revs + assert rev3 in dvc.experiments.stash_revs + assert scm.get_ref(str(ref_info1)) is not None + assert scm.get_ref(str(ref_info2)) is not None + + assert dvc.experiments.remove(["queue1", rev2[:5], str(ref_info1)]) == 3 + assert rev1 not in dvc.experiments.stash_revs + assert rev2 not in dvc.experiments.stash_revs + assert rev3 in dvc.experiments.stash_revs + assert scm.get_ref(str(ref_info1)) is None + assert scm.get_ref(str(ref_info2)) is not None