diff --git a/dvc/command/experiments.py b/dvc/command/experiments.py index fa79072cf5..e1fa5d83f1 100644 --- a/dvc/command/experiments.py +++ b/dvc/command/experiments.py @@ -747,7 +747,9 @@ class CmdExperimentsRemove(CmdBase): def run(self): self.repo.experiments.remove( - exp_names=self.args.experiment, queue=self.args.queue + exp_names=self.args.experiment, + queue=self.args.queue, + clear_all=self.args.all, ) return 0 @@ -1237,9 +1239,16 @@ def add_parser(subparsers, parent_parser): help=EXPERIMENTS_REMOVE_HELP, formatter_class=argparse.RawDescriptionHelpFormatter, ) - experiments_remove_parser.add_argument( + remove_group = experiments_remove_parser.add_mutually_exclusive_group() + remove_group.add_argument( "--queue", action="store_true", help="Remove all queued experiments." ) + remove_group.add_argument( + "-A", + "--all", + action="store_true", + help="Remove all committed experiments.", + ) experiments_remove_parser.add_argument( "experiment", nargs="*", diff --git a/dvc/repo/experiments/remove.py b/dvc/repo/experiments/remove.py index 8755ac3004..2df54728bd 100644 --- a/dvc/repo/experiments/remove.py +++ b/dvc/repo/experiments/remove.py @@ -7,21 +7,29 @@ from dvc.scm.base import RevError from .base import EXPS_NAMESPACE, ExpRefInfo -from .utils import exp_refs_by_name, remove_exp_refs +from .utils import exp_refs, exp_refs_by_name, remove_exp_refs logger = logging.getLogger(__name__) @locked @scm_context -def remove(repo, exp_names=None, queue=False, **kwargs): - if not exp_names and not queue: +def remove( + repo, + exp_names=None, + queue=False, + clear_all=False, + **kwargs, +): + if not any([exp_names, queue, clear_all]): return 0 removed = 0 if queue: - removed += len(repo.experiments.stash) - repo.experiments.stash.clear() + removed += _clear_stash(repo) + if clear_all: + removed += _clear_all(repo) + if exp_names: remained = _remove_commited_exps(repo, exp_names) remained = _remove_queued_exps(repo, remained) @@ -33,6 +41,18 @@ def remove(repo, exp_names=None, queue=False, **kwargs): return removed +def _clear_stash(repo): + removed = len(repo.experiments.stash) + repo.experiments.stash.clear() + return removed + + +def _clear_all(repo): + ref_infos = list(exp_refs(repo.scm)) + remove_exp_refs(repo.scm, ref_infos) + return len(ref_infos) + + def _get_exp_stash_index(repo, ref_or_rev: str) -> Optional[int]: stash_revs = repo.experiments.stash_revs for _, ref_info in stash_revs.items(): @@ -53,9 +73,9 @@ def _get_exp_ref(repo, exp_name: str) -> Optional[ExpRefInfo]: if repo.scm.get_ref(exp_name): return ExpRefInfo.from_ref(exp_name) else: - exp_refs = list(exp_refs_by_name(repo.scm, exp_name)) - if exp_refs: - return _get_ref(exp_refs, exp_name, cur_rev) + exp_ref_list = list(exp_refs_by_name(repo.scm, exp_name)) + if exp_ref_list: + return _get_ref(exp_ref_list, exp_name, cur_rev) return None diff --git a/tests/func/experiments/test_remove.py b/tests/func/experiments/test_remove.py index f12f142532..909561a080 100644 --- a/tests/func/experiments/test_remove.py +++ b/tests/func/experiments/test_remove.py @@ -72,3 +72,21 @@ def test_remove_special_queued_experiments(tmp_dir, scm, dvc, exp_stage): assert rev3 in dvc.experiments.stash_revs assert scm.get_ref(str(ref_info1)) is None assert scm.get_ref(str(ref_info2)) is not None + + +def test_remove_all(tmp_dir, scm, dvc, exp_stage): + results = dvc.experiments.run(exp_stage.addressing, params=["foo=1"]) + ref_info1 = first(exp_refs_by_rev(scm, first(results))) + dvc.experiments.run(exp_stage.addressing, params=["foo=2"], queue=True) + scm.add(["dvc.yaml", "dvc.lock", "copy.py", "params.yaml", "metrics.yaml"]) + scm.commit("update baseline") + + results = dvc.experiments.run(exp_stage.addressing, params=["foo=3"]) + ref_info2 = first(exp_refs_by_rev(scm, first(results))) + dvc.experiments.run(exp_stage.addressing, params=["foo=4"], queue=True) + + removed = dvc.experiments.remove(clear_all=True) + assert removed == 2 + assert len(dvc.experiments.stash) == 2 + assert scm.get_ref(str(ref_info2)) is None + assert scm.get_ref(str(ref_info1)) is None diff --git a/tests/unit/command/test_experiments.py b/tests/unit/command/test_experiments.py index e271f14644..d8043af94b 100644 --- a/tests/unit/command/test_experiments.py +++ b/tests/unit/command/test_experiments.py @@ -252,13 +252,25 @@ def test_experiments_pull(dvc, scm, mocker): ) -def test_experiments_remove(dvc, scm, mocker): - cli_args = parse_args(["experiments", "remove", "--queue"]) +@pytest.mark.parametrize( + "queue,clear_all", + [(True, False), (False, True)], +) +def test_experiments_remove(dvc, scm, mocker, queue, clear_all): + if queue: + args = "--queue" + if clear_all: + args = "--all" + cli_args = parse_args(["experiments", "remove", args]) assert cli_args.func == CmdExperimentsRemove cmd = cli_args.func(cli_args) m = mocker.patch("dvc.repo.experiments.remove.remove", return_value={}) assert cmd.run() == 0 - - m.assert_called_once_with(cmd.repo, exp_names=[], queue=True) + m.assert_called_once_with( + cmd.repo, + exp_names=[], + queue=queue, + clear_all=clear_all, + )