Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 58 additions & 20 deletions dvc/repo/experiments/remove.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import logging
from typing import List, Optional

from dvc.exceptions import InvalidArgumentError
from dvc.repo import locked
from dvc.repo.scm_context import scm_context
from dvc.scm.base import RevError

from .base import EXPS_NAMESPACE, ExpRefInfo
from .utils import exp_refs_by_name, remove_exp_refs
Expand All @@ -21,32 +23,43 @@ def remove(repo, exp_names=None, queue=False, **kwargs):
removed += len(repo.experiments.stash)
repo.experiments.stash.clear()
if exp_names:
ref_infos = list(_get_exp_refs(repo, exp_names))
remove_exp_refs(repo.scm, ref_infos)
removed += len(ref_infos)
remained = _remove_commited_exps(repo, exp_names)
remained = _remove_queued_exps(repo, remained)
if remained:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A behavior change here, because we had already removed all of the matched experiments, raise an exception here is needless.

raise InvalidArgumentError(
"'{}' is not a valid experiment".format(";".join(remained))
)
removed += len(exp_names) - len(remained)
return removed


def _get_exp_refs(repo, exp_names):
cur_rev = repo.scm.get_rev()
for name in exp_names:
if name.startswith(EXPS_NAMESPACE):
if not repo.scm.get_ref(name):
raise InvalidArgumentError(
f"'{name}' is not a valid experiment name"
)
yield ExpRefInfo.from_ref(name)
else:
def _get_exp_stash_index(repo, ref_or_rev: str) -> Optional[int]:
stash_revs = repo.experiments.stash_revs
for _, ref_info in stash_revs.items():
if ref_info.name == ref_or_rev:
return ref_info.index
try:
rev = repo.scm.resolve_rev(ref_or_rev)
if rev in stash_revs:
return stash_revs.get(rev).index
except RevError:
pass
return None

exp_refs = list(exp_refs_by_name(repo.scm, name))
if not exp_refs:
raise InvalidArgumentError(
f"'{name}' is not a valid experiment name"
)
yield _get_ref(exp_refs, name, cur_rev)

def _get_exp_ref(repo, exp_name: str) -> Optional[ExpRefInfo]:
cur_rev = repo.scm.get_rev()
if exp_name.startswith(EXPS_NAMESPACE):
if repo.scm.get_ref(exp_name):
return ExpRefInfo.from_ref(exp_name)
else:
exp_refs = list(exp_refs_by_name(repo.scm, exp_name))
if exp_refs:
return _get_ref(exp_refs, exp_name, cur_rev)
return None


def _get_ref(ref_infos, name, cur_rev):
def _get_ref(ref_infos, name, cur_rev) -> Optional[ExpRefInfo]:
if len(ref_infos) > 1:
for info in ref_infos:
if info.baseline_sha == cur_rev:
Expand All @@ -61,3 +74,28 @@ def _get_ref(ref_infos, name, cur_rev):
msg.extend([f"\t{info}" for info in ref_infos])
raise InvalidArgumentError("\n".join(msg))
return ref_infos[0]


def _remove_commited_exps(repo, refs: List[str]) -> List[str]:
remain_list = []
remove_list = []
for ref in refs:
ref_info = _get_exp_ref(repo, ref)
if ref_info:
remove_list.append(ref_info)
else:
remain_list.append(ref)
if remove_list:
remove_exp_refs(repo.scm, remove_list)
return remain_list


def _remove_queued_exps(repo, refs_or_revs: List[str]) -> List[str]:
remain_list = []
for ref_or_rev in refs_or_revs:
stash_index = _get_exp_stash_index(repo, ref_or_rev)
if stash_index is None:
remain_list.append(ref_or_rev)
else:
repo.experiments.stash.drop(stash_index)
return remain_list
15 changes: 0 additions & 15 deletions tests/func/experiments/test_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,21 +576,6 @@ def test_run_metrics(tmp_dir, scm, dvc, exp_stage, mocker):
assert show_mock.called_once()


def test_remove(tmp_dir, scm, dvc, exp_stage):
results = dvc.experiments.run(exp_stage.addressing, params=["foo=2"])
exp = first(results)
ref_info = first(exp_refs_by_rev(scm, exp))
dvc.experiments.run(exp_stage.addressing, params=["foo=3"], queue=True)

removed = dvc.experiments.remove([str(ref_info)])
assert removed == 1
assert scm.get_ref(str(ref_info)) is None

removed = dvc.experiments.remove(queue=True)
assert removed == 1
assert len(dvc.experiments.stash) == 0


def test_checkout_targets_deps(tmp_dir, scm, dvc, exp_stage):
from dvc.utils.fs import remove

Expand Down
74 changes: 74 additions & 0 deletions tests/func/experiments/test_remove.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import pytest
from funcy import first

from dvc.exceptions import InvalidArgumentError
from dvc.repo.experiments.utils import exp_refs_by_rev


def test_remove_experiments_by_ref(tmp_dir, scm, dvc, exp_stage, caplog):
queue_length = 3
ref_list = []

for i in range(queue_length):
results = dvc.experiments.run(
exp_stage.addressing, params=[f"foo={i}"]
)
ref_info = first(exp_refs_by_rev(scm, first(results)))
ref_list.append(str(ref_info))

with pytest.raises(InvalidArgumentError):
assert dvc.experiments.remove(ref_list[:2] + ["non-exist"])
assert scm.get_ref(str(ref_list[0])) is None
assert scm.get_ref(str(ref_list[1])) is None
assert scm.get_ref(str(ref_list[2])) is not None


def test_remove_all_queued_experiments(tmp_dir, scm, dvc, exp_stage):
queue_length = 3

for i in range(queue_length):
dvc.experiments.run(
exp_stage.addressing, params=[f"foo={i}"], queue=True
)

results = dvc.experiments.run(
exp_stage.addressing, params=[f"foo={queue_length}"]
)
ref_info = first(exp_refs_by_rev(scm, first(results)))

assert len(dvc.experiments.stash) == queue_length
assert dvc.experiments.remove(queue=True) == queue_length
assert len(dvc.experiments.stash) == 0
assert scm.get_ref(str(ref_info)) is not None


def test_remove_special_queued_experiments(tmp_dir, scm, dvc, exp_stage):
results = dvc.experiments.run(
exp_stage.addressing, params=["foo=1"], queue=True, name="queue1"
)
rev1 = first(results)
results = dvc.experiments.run(
exp_stage.addressing, params=["foo=2"], queue=True, name="queue2"
)
rev2 = first(results)
results = dvc.experiments.run(
exp_stage.addressing, params=["foo=3"], queue=True, name="queue3"
)
rev3 = first(results)
results = dvc.experiments.run(exp_stage.addressing, params=["foo=4"])
ref_info1 = first(exp_refs_by_rev(scm, first(results)))
results = dvc.experiments.run(exp_stage.addressing, params=["foo=5"])
ref_info2 = first(exp_refs_by_rev(scm, first(results)))

assert rev1 in dvc.experiments.stash_revs
assert rev2 in dvc.experiments.stash_revs
assert rev3 in dvc.experiments.stash_revs
assert scm.get_ref(str(ref_info1)) is not None
assert scm.get_ref(str(ref_info2)) is not None

assert dvc.experiments.remove(["queue1", rev2[:5], str(ref_info1)]) == 3
assert rev1 not in dvc.experiments.stash_revs
assert rev2 not in dvc.experiments.stash_revs
assert rev3 in dvc.experiments.stash_revs
assert scm.get_ref(str(ref_info1)) is None
assert scm.get_ref(str(ref_info2)) is not None