From 54f3e5304dcd40f235a3bbed25c0ecd6c20697a7 Mon Sep 17 00:00:00 2001 From: karajan1001 Date: Sat, 25 Dec 2021 10:55:42 +0800 Subject: [PATCH] exp show: add `--rev` flag (#7152) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit fix: #7152 1. Add `--rev` flag to `dvc exp show`. 2. Add -1 support for the `--num` flag. 3. Extract the revision-finding logic into `iter_revs` in `dvc/scm.py`. 4. Brancher uses this revision-finding logic. 5. Update the command unit test. 6. Add a unit test for `iter_revs`. Co-authored-by: Peter Rowlands (변기호) --- dvc/command/experiments/show.py | 15 +++++- dvc/repo/brancher.py | 51 ++++++++------------- dvc/repo/experiments/show.py | 47 ++++++------------- dvc/scm.py | 59 +++++++++++++++++++++++- tests/func/experiments/test_show.py | 3 +- tests/unit/command/test_experiments.py | 5 +- tests/unit/scm/test_scm.py | 63 ++++++++++++++++++++++++++ 7 files changed, 173 insertions(+), 70 deletions(-) create mode 100644 tests/unit/scm/test_scm.py diff --git a/dvc/command/experiments/show.py b/dvc/command/experiments/show.py index 3ed78b5a4a..76eb5c9c49 100644 --- a/dvc/command/experiments/show.py +++ b/dvc/command/experiments/show.py @@ -532,8 +532,9 @@ def run(self): all_branches=self.args.all_branches, all_tags=self.args.all_tags, all_commits=self.args.all_commits, - sha_only=self.args.sha, + revs=self.args.rev, num=self.args.num, + sha_only=self.args.sha, param_deps=self.args.param_deps, ) except DvcException: @@ -604,6 +605,16 @@ def add_parser(experiments_subparsers, parent_parser): default=False, help="Show experiments derived from all Git commits.", ) + experiments_show_parser.add_argument( + "--rev", + type=str, + default=None, + help=( + "Show experiments derived from the specified revision. " + "Defaults to HEAD if none of `--rev`,`-a`,`-A`,`-T` is specified." 
+ ), + metavar="", + ) experiments_show_parser.add_argument( "-n", "--num", @@ -611,7 +622,7 @@ def add_parser(experiments_subparsers, parent_parser): default=1, dest="num", metavar="", - help="Show the last `num` commits from HEAD.", + help="Show the last `num` commits from .", ) experiments_show_parser.add_argument( "--no-pager", diff --git a/dvc/repo/brancher.py b/dvc/repo/brancher.py index 5002632c34..30e6d758c2 100644 --- a/dvc/repo/brancher.py +++ b/dvc/repo/brancher.py @@ -1,6 +1,4 @@ -from functools import partial - -from funcy import group_by +from dvc.scm import iter_revs def brancher( # noqa: E302 @@ -34,44 +32,33 @@ def brancher( # noqa: E302 from dvc.fs.local import LocalFileSystem saved_fs = self.fs - revs = revs.copy() if revs else [] scm = self.scm self.fs = LocalFileSystem(url=self.root_dir) yield "workspace" - if revs and "workspace" in revs: + revs = revs.copy() if revs else [] + if "workspace" in revs: revs.remove("workspace") - if all_commits: - revs = scm.list_all_commits() - else: - if all_branches: - revs.extend(scm.list_branches()) - - if all_tags: - revs.extend(scm.list_tags()) - - if all_experiments: - from dvc.repo.experiments.utils import exp_commits - - revs.extend(exp_commits(scm)) + found_revs = iter_revs( + scm, + revs, + all_branches=all_branches, + all_tags=all_tags, + all_commits=all_commits, + all_experiments=all_experiments, + ) try: - if revs: - from dvc.fs.git import GitFileSystem - from dvc.scm import resolve_rev - - rev_resolver = partial(resolve_rev, scm) - for sha, names in group_by(rev_resolver, revs).items(): - self.fs = GitFileSystem(scm=scm, rev=sha) - # ignore revs that don't contain repo root - # (i.e. 
revs from before a subdir=True repo was init'ed) - if self.fs.exists(self.root_dir): - if sha_only: - yield sha - else: - yield ", ".join(names) + from dvc.fs.git import GitFileSystem + + for sha, names in found_revs.items(): + self.fs = GitFileSystem(scm=scm, rev=sha) + # ignore revs that don't contain repo root + # (i.e. revs from before a subdir=True repo was init'ed) + if self.fs.exists(self.root_dir): + yield sha if sha_only else ",".join(names) finally: self.fs = saved_fs diff --git a/dvc/repo/experiments/show.py b/dvc/repo/experiments/show.py index b65d7016c0..2a4a96730f 100644 --- a/dvc/repo/experiments/show.py +++ b/dvc/repo/experiments/show.py @@ -1,14 +1,13 @@ import logging from collections import OrderedDict, defaultdict from datetime import datetime -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Dict, List, Optional, Union -from dvc.exceptions import InvalidArgumentError -from dvc.repo import locked +from dvc.repo import Repo, locked # pylint: disable=unused-import from dvc.repo.experiments.base import ExpRefInfo -from dvc.repo.experiments.utils import fix_exp_head from dvc.repo.metrics.show import _gather_metrics from dvc.repo.params.show import _gather_params +from dvc.scm import iter_revs from dvc.utils import error_handler, onerror_collect logger = logging.getLogger(__name__) @@ -102,10 +101,10 @@ def _collect_experiment_branch( @locked def show( - repo, + repo: "Repo", all_branches=False, all_tags=False, - revs=None, + revs: Union[List[str], str, None] = None, all_commits=False, sha_only=False, num=1, @@ -117,35 +116,19 @@ def show( res: Dict[str, Dict] = defaultdict(OrderedDict) - if num < 1: - raise InvalidArgumentError(f"Invalid number of commits '{num}'") - - if revs is None: - from dvc.scm import RevError, resolve_rev - - revs = [] - for n in range(num): - try: - head = fix_exp_head(repo.scm, f"HEAD~{n}") - assert head - revs.append(resolve_rev(repo.scm, head)) - except RevError: - break - - revs = 
OrderedDict( - (rev, None) - for rev in repo.brancher( - revs=revs, - all_branches=all_branches, - all_tags=all_tags, - all_commits=all_commits, - sha_only=True, - ) + if not any([revs, all_branches, all_tags, all_commits]): + revs = ["HEAD"] + if isinstance(revs, str): + revs = [revs] + + found_revs: Dict[str, List[str]] = {"workspace": []} + found_revs.update( + iter_revs(repo.scm, revs, num, all_branches, all_tags, all_commits) ) running = repo.experiments.get_running_exps() - for rev in revs: + for rev in found_revs: res[rev]["baseline"] = _collect_experiment_commit( repo, rev, @@ -180,7 +163,7 @@ def show( ) # collect queued (not yet reproduced) experiments for stash_rev, entry in repo.experiments.stash_revs.items(): - if entry.baseline_rev in revs: + if entry.baseline_rev in found_revs: if stash_rev not in running or not running[stash_rev].get( "last" ): diff --git a/dvc/scm.py b/dvc/scm.py index c7b6ab551d..064174dae2 100644 --- a/dvc/scm.py +++ b/dvc/scm.py @@ -1,12 +1,14 @@ """Manages source control systems (e.g. 
Git).""" from contextlib import contextmanager -from typing import TYPE_CHECKING, Iterator +from functools import partial +from typing import TYPE_CHECKING, Iterator, List, Mapping, Optional +from funcy import group_by from scmrepo.base import Base # noqa: F401, pylint: disable=unused-import from scmrepo.git import Git from scmrepo.noscm import NoSCM -from dvc.exceptions import DvcException +from dvc.exceptions import DvcException, InvalidArgumentError from dvc.progress import Tqdm if TYPE_CHECKING: @@ -123,3 +125,56 @@ def resolve_rev(scm: "Git", rev: str) -> str: if len(ref_infos) > 1: raise RevError(f"ambiguous Git revision '{rev}'") raise RevError(str(exc)) + + +def iter_revs( + scm: "Git", + head_revs: Optional[List[str]] = None, + num: int = 1, + all_branches: bool = False, + all_tags: bool = False, + all_commits: bool = False, + all_experiments: bool = False, +) -> Mapping[str, List[str]]: + from dvc.repo.experiments.utils import fix_exp_head + + if num < 1 and num != -1: + raise InvalidArgumentError(f"Invalid number of commits '{num}'") + + if not any( + [head_revs, all_branches, all_tags, all_commits, all_experiments] + ): + return {} + + head_revs = head_revs or [] + revs = [] + for rev in head_revs: + revs.append(rev) + n = 1 + while True: + if num == n: + break + try: + head = fix_exp_head(scm, f"{rev}~{n}") + assert head + revs.append(resolve_rev(scm, head)) + except RevError: + break + n += 1 + + if all_commits: + revs.extend(scm.list_all_commits()) + else: + if all_branches: + revs.extend(scm.list_branches()) + + if all_tags: + revs.extend(scm.list_tags()) + + if all_experiments: + from dvc.repo.experiments.utils import exp_commits + + revs.extend(exp_commits(scm)) + + rev_resolver = partial(resolve_rev, scm) + return group_by(rev_resolver, revs) diff --git a/tests/func/experiments/test_show.py b/tests/func/experiments/test_show.py index e45ef58993..470780d15b 100644 --- a/tests/func/experiments/test_show.py +++ b/tests/func/experiments/test_show.py 
@@ -351,8 +351,9 @@ def test_show_multiple_commits(tmp_dir, scm, dvc, exp_stage): tmp_dir.scm_gen("file", "file", "commit") next_rev = scm.get_rev() + dvc.experiments.show(num=-1) with pytest.raises(InvalidArgumentError): - dvc.experiments.show(num=-1) + dvc.experiments.show(num=-2) expected = {"workspace", init_rev, next_rev} results = dvc.experiments.show(num=2) diff --git a/tests/unit/command/test_experiments.py b/tests/unit/command/test_experiments.py index 1257910343..a3d2c8f434 100644 --- a/tests/unit/command/test_experiments.py +++ b/tests/unit/command/test_experiments.py @@ -99,6 +99,8 @@ def test_experiments_show(dvc, scm, mocker): "--param-deps", "-n", "1", + "--rev", + "foo", ] ) assert cli_args.func == CmdExperimentsShow @@ -113,8 +115,9 @@ def test_experiments_show(dvc, scm, mocker): all_tags=True, all_branches=True, all_commits=True, - sha_only=True, num=1, + revs="foo", + sha_only=True, param_deps=True, ) diff --git a/tests/unit/scm/test_scm.py b/tests/unit/scm/test_scm.py new file mode 100644 index 0000000000..338cc74ce4 --- /dev/null +++ b/tests/unit/scm/test_scm.py @@ -0,0 +1,63 @@ +from dvc.repo.experiments import ExpRefInfo +from dvc.scm import iter_revs + + +def test_iter_revs( + tmp_dir, + scm, +): + """ + new other + │ │ + old (tag) ────┘ + │ + root + """ + old = scm.active_branch() + tmp_dir.scm_gen("foo", "init", commit="init") + rev_root = scm.get_rev() + tmp_dir.scm_gen("foo", "old", commit="old") + rev_old = scm.get_rev() + scm.checkout("new", create_new=True) + tmp_dir.scm_gen("foo", "new", commit="new") + rev_new = scm.get_rev() + scm.checkout(old) + scm.tag("tag") + scm.checkout("other", create_new=True) + tmp_dir.scm_gen("foo", "other", commit="new") + rev_other = scm.get_rev() + + ref = ExpRefInfo(rev_root, "exp1") + scm.set_ref(str(ref), rev_new) + ref = ExpRefInfo(rev_root, "exp2") + scm.set_ref(str(ref), rev_old) + + gen = iter_revs(scm, [rev_root, "new"], 1) + assert gen == {rev_root: [rev_root], rev_new: ["new"]} + gen = 
iter_revs(scm, ["new"], 2) + assert gen == {rev_new: ["new"], rev_old: [rev_old]} + gen = iter_revs(scm, ["other"], -1) + assert gen == { + rev_other: ["other"], + rev_old: [rev_old], + rev_root: [rev_root], + } + gen = iter_revs(scm, ["tag"]) + assert gen == {rev_old: ["tag"]} + gen = iter_revs(scm, all_branches=True) + assert gen == {rev_old: [old], rev_new: ["new"], rev_other: ["other"]} + gen = iter_revs(scm, all_tags=True) + assert gen == {rev_old: ["tag"]} + gen = iter_revs(scm, all_commits=True) + assert gen == { + rev_old: [rev_old], + rev_new: [rev_new], + rev_other: [rev_other], + rev_root: [rev_root], + } + gen = iter_revs(scm, all_experiments=True) + assert gen == { + rev_new: [rev_new], + rev_old: [rev_old], + rev_root: [rev_root], + }