Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions dvc/command/experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -747,7 +747,9 @@ class CmdExperimentsRemove(CmdBase):
def run(self):

self.repo.experiments.remove(
exp_names=self.args.experiment, queue=self.args.queue
exp_names=self.args.experiment,
queue=self.args.queue,
clear_all=self.args.all,
)

return 0
Expand Down Expand Up @@ -1237,9 +1239,16 @@ def add_parser(subparsers, parent_parser):
help=EXPERIMENTS_REMOVE_HELP,
formatter_class=argparse.RawDescriptionHelpFormatter,
)
experiments_remove_parser.add_argument(
remove_group = experiments_remove_parser.add_mutually_exclusive_group()
remove_group.add_argument(
"--queue", action="store_true", help="Remove all queued experiments."
)
remove_group.add_argument(
"-A",
"--all",
action="store_true",
help="Remove all committed experiments.",
)
experiments_remove_parser.add_argument(
"experiment",
nargs="*",
Expand Down
36 changes: 28 additions & 8 deletions dvc/repo/experiments/remove.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,29 @@
from dvc.scm.base import RevError

from .base import EXPS_NAMESPACE, ExpRefInfo
from .utils import exp_refs_by_name, remove_exp_refs
from .utils import exp_refs, exp_refs_by_name, remove_exp_refs

logger = logging.getLogger(__name__)


@locked
@scm_context
def remove(repo, exp_names=None, queue=False, **kwargs):
if not exp_names and not queue:
def remove(
repo,
exp_names=None,
queue=False,
clear_all=False,
**kwargs,
):
if not any([exp_names, queue, clear_all]):
return 0

removed = 0
if queue:
removed += len(repo.experiments.stash)
repo.experiments.stash.clear()
removed += _clear_stash(repo)
if clear_all:
removed += _clear_all(repo)

if exp_names:
remained = _remove_commited_exps(repo, exp_names)
remained = _remove_queued_exps(repo, remained)
Expand All @@ -33,6 +41,18 @@ def remove(repo, exp_names=None, queue=False, **kwargs):
return removed


def _clear_stash(repo):
removed = len(repo.experiments.stash)
repo.experiments.stash.clear()
return removed


def _clear_all(repo):
ref_infos = list(exp_refs(repo.scm))
remove_exp_refs(repo.scm, ref_infos)
return len(ref_infos)


def _get_exp_stash_index(repo, ref_or_rev: str) -> Optional[int]:
stash_revs = repo.experiments.stash_revs
for _, ref_info in stash_revs.items():
Expand All @@ -53,9 +73,9 @@ def _get_exp_ref(repo, exp_name: str) -> Optional[ExpRefInfo]:
if repo.scm.get_ref(exp_name):
return ExpRefInfo.from_ref(exp_name)
else:
exp_refs = list(exp_refs_by_name(repo.scm, exp_name))
if exp_refs:
return _get_ref(exp_refs, exp_name, cur_rev)
exp_ref_list = list(exp_refs_by_name(repo.scm, exp_name))
if exp_ref_list:
return _get_ref(exp_ref_list, exp_name, cur_rev)
return None


Expand Down
18 changes: 18 additions & 0 deletions tests/func/experiments/test_remove.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,21 @@ def test_remove_special_queued_experiments(tmp_dir, scm, dvc, exp_stage):
assert rev3 in dvc.experiments.stash_revs
assert scm.get_ref(str(ref_info1)) is None
assert scm.get_ref(str(ref_info2)) is not None


def test_remove_all(tmp_dir, scm, dvc, exp_stage):
results = dvc.experiments.run(exp_stage.addressing, params=["foo=1"])
ref_info1 = first(exp_refs_by_rev(scm, first(results)))
dvc.experiments.run(exp_stage.addressing, params=["foo=2"], queue=True)
scm.add(["dvc.yaml", "dvc.lock", "copy.py", "params.yaml", "metrics.yaml"])
scm.commit("update baseline")

results = dvc.experiments.run(exp_stage.addressing, params=["foo=3"])
ref_info2 = first(exp_refs_by_rev(scm, first(results)))
dvc.experiments.run(exp_stage.addressing, params=["foo=4"], queue=True)

removed = dvc.experiments.remove(clear_all=True)
assert removed == 2
assert len(dvc.experiments.stash) == 2
assert scm.get_ref(str(ref_info2)) is None
assert scm.get_ref(str(ref_info1)) is None
20 changes: 16 additions & 4 deletions tests/unit/command/test_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,13 +252,25 @@ def test_experiments_pull(dvc, scm, mocker):
)


def test_experiments_remove(dvc, scm, mocker):
cli_args = parse_args(["experiments", "remove", "--queue"])
@pytest.mark.parametrize(
"queue,clear_all",
[(True, False), (False, True)],
)
def test_experiments_remove(dvc, scm, mocker, queue, clear_all):
if queue:
args = "--queue"
if clear_all:
args = "--all"
cli_args = parse_args(["experiments", "remove", args])
assert cli_args.func == CmdExperimentsRemove

cmd = cli_args.func(cli_args)
m = mocker.patch("dvc.repo.experiments.remove.remove", return_value={})

assert cmd.run() == 0

m.assert_called_once_with(cmd.repo, exp_names=[], queue=True)
m.assert_called_once_with(
cmd.repo,
exp_names=[],
queue=queue,
clear_all=clear_all,
)