From 5106d46adb79ca0659e35db5ad5725e77de2a886 Mon Sep 17 00:00:00 2001 From: karajan1001 Date: Wed, 12 Jan 2022 15:15:42 +0800 Subject: [PATCH] exp push/pull: add `--all-commits` flag and unify the collection of revs (#7154) fix: #7154 1. rename flags `-A/--all-commits` in `exp pull/push` 2. add new flags `--rev` and `--num` in `exp pull/push` 3. Unify the collection of revs in `exp push/pull` 4. add unit and func tests for `exp push/pull` --- dvc/commands/experiments/pull.py | 33 ++++++++-- dvc/commands/experiments/push.py | 34 ++++++++-- dvc/repo/experiments/pull.py | 71 +++++++++++++-------- dvc/repo/experiments/push.py | 86 +++++++++++++++++--------- tests/func/experiments/test_remote.py | 66 ++++++++++++++++++++ tests/unit/command/test_experiments.py | 46 ++++++++++++++ 6 files changed, 269 insertions(+), 67 deletions(-) diff --git a/dvc/commands/experiments/pull.py b/dvc/commands/experiments/pull.py index 3cd7e0ea89..9709dcfd11 100644 --- a/dvc/commands/experiments/pull.py +++ b/dvc/commands/experiments/pull.py @@ -3,17 +3,31 @@ from dvc.cli.command import CmdBase from dvc.cli.utils import append_doc_link +from dvc.exceptions import InvalidArgumentError from dvc.ui import ui logger = logging.getLogger(__name__) class CmdExperimentsPull(CmdBase): + def raise_error_if_all_disabled(self): + if not any( + [self.args.experiment, self.args.all_commits, self.args.rev] + ): + raise InvalidArgumentError( + "Either provide an `experiment` argument, or use the " + "`--rev` or `--all-commits` flag." 
+ ) + def run(self): + self.raise_error_if_all_disabled() - self.repo.experiments.pull( + pulled_exps = self.repo.experiments.pull( self.args.git_remote, self.args.experiment, + all_commits=self.args.all_commits, + rev=self.args.rev, + num=self.args.num, force=self.args.force, pull_cache=self.args.pull_cache, dvc_remote=self.args.dvc_remote, @@ -21,10 +35,13 @@ def run(self): run_cache=self.args.run_cache, ) - ui.write( - f"Pulled experiment '{self.args.experiment}'", - f"from Git remote '{self.args.git_remote}'.", - ) + if pulled_exps: + ui.write( + f"Pulled experiment '{pulled_exps}'", + f"from Git remote '{self.args.git_remote}'.", + ) + else: + ui.write("No experiments to pull.") if not self.args.pull_cache: ui.write( "To pull cached outputs for this experiment" @@ -36,6 +53,8 @@ def run(self): def add_parser(experiments_subparsers, parent_parser): + from . import add_rev_selection_flags + EXPERIMENTS_PULL_HELP = "Pull an experiment from a Git remote." experiments_pull_parser = experiments_subparsers.add_parser( "pull", @@ -44,6 +63,7 @@ def add_parser(experiments_subparsers, parent_parser): help=EXPERIMENTS_PULL_HELP, formatter_class=argparse.RawDescriptionHelpFormatter, ) + add_rev_selection_flags(experiments_pull_parser, "Pull", False) experiments_pull_parser.add_argument( "-f", "--force", @@ -89,7 +109,8 @@ def add_parser(experiments_subparsers, parent_parser): ) experiments_pull_parser.add_argument( "experiment", - nargs="+", + nargs="*", + default=None, help="Experiments to pull.", metavar="", ) diff --git a/dvc/commands/experiments/push.py b/dvc/commands/experiments/push.py index 15d6e95cae..1cd2467a38 100644 --- a/dvc/commands/experiments/push.py +++ b/dvc/commands/experiments/push.py @@ -4,17 +4,32 @@ from dvc.cli.command import CmdBase from dvc.cli.utils import append_doc_link from dvc.commands import completion +from dvc.exceptions import InvalidArgumentError from dvc.ui import ui logger = logging.getLogger(__name__) class 
CmdExperimentsPush(CmdBase): + def raise_error_if_all_disabled(self): + if not any( + [self.args.experiment, self.args.all_commits, self.args.rev] + ): + raise InvalidArgumentError( + "Either provide an `experiment` argument, or use the " + "`--rev` or `--all-commits` flag." + ) + def run(self): - self.repo.experiments.push( + self.raise_error_if_all_disabled() + + pushed_exps = self.repo.experiments.push( self.args.git_remote, self.args.experiment, + all_commits=self.args.all_commits, + rev=self.args.rev, + num=self.args.num, force=self.args.force, push_cache=self.args.push_cache, dvc_remote=self.args.dvc_remote, @@ -22,10 +37,13 @@ def run(self): run_cache=self.args.run_cache, ) - ui.write( - f"Pushed experiment '{self.args.experiment}'" - f"to Git remote '{self.args.git_remote}'." - ) + if pushed_exps: + ui.write( + f"Pushed experiment '{pushed_exps}' " + f"to Git remote '{self.args.git_remote}'." + ) + else: + ui.write("No experiments to push.") if not self.args.push_cache: ui.write( "To push cached outputs", @@ -37,6 +55,8 @@ def run(self): def add_parser(experiments_subparsers, parent_parser): + from . import add_rev_selection_flags + EXPERIMENTS_PUSH_HELP = "Push a local experiment to a Git remote." 
experiments_push_parser = experiments_subparsers.add_parser( "push", @@ -45,6 +65,7 @@ def add_parser(experiments_subparsers, parent_parser): help=EXPERIMENTS_PUSH_HELP, formatter_class=argparse.RawDescriptionHelpFormatter, ) + add_rev_selection_flags(experiments_push_parser, "Push", False) experiments_push_parser.add_argument( "-f", "--force", @@ -90,7 +111,8 @@ def add_parser(experiments_subparsers, parent_parser): ) experiments_push_parser.add_argument( "experiment", - nargs="+", + nargs="*", + default=None, help="Experiments to push.", metavar="", ).complete = completion.EXPERIMENT diff --git a/dvc/repo/experiments/pull.py b/dvc/repo/experiments/pull.py index 78b0442db4..488e6a3cae 100644 --- a/dvc/repo/experiments/pull.py +++ b/dvc/repo/experiments/pull.py @@ -1,13 +1,14 @@ import logging -from typing import Iterable, Union +from typing import Iterable, Optional, Set, Union from dvc.exceptions import DvcException from dvc.repo import locked from dvc.repo.scm_context import scm_context -from dvc.scm import TqdmGit +from dvc.scm import TqdmGit, iter_revs +from .base import ExpRefInfo from .exceptions import UnresolvedExpNamesError -from .utils import exp_commits, resolve_name +from .utils import exp_commits, exp_refs, exp_refs_by_baseline, resolve_name logger = logging.getLogger(__name__) @@ -18,33 +19,50 @@ def pull( repo, git_remote: str, exp_names: Union[Iterable[str], str], - *args, + all_commits=False, + rev: Optional[str] = None, + num=1, force: bool = False, pull_cache: bool = False, **kwargs, -): - if isinstance(exp_names, str): - exp_names = [exp_names] - exp_ref_dict = resolve_name(repo.scm, exp_names, git_remote) - unresolved_exp_names = [ - exp_name - for exp_name, exp_ref in exp_ref_dict.items() - if exp_ref is None - ] - if unresolved_exp_names: - raise UnresolvedExpNamesError(unresolved_exp_names) +) -> Iterable[str]: + exp_ref_set: Set["ExpRefInfo"] = set() + if all_commits: + exp_ref_set.update(exp_refs(repo.scm, git_remote)) + else: + if 
exp_names: + if isinstance(exp_names, str): + exp_names = [exp_names] + exp_ref_dict = resolve_name(repo.scm, exp_names, git_remote) + + unresolved_exp_names = [] + for exp_name, exp_ref in exp_ref_dict.items(): + if exp_ref is None: + unresolved_exp_names.append(exp_name) + else: + exp_ref_set.add(exp_ref) + + if unresolved_exp_names: + raise UnresolvedExpNamesError(unresolved_exp_names) + + if rev: + rev_dict = iter_revs(repo.scm, [rev], num) + rev_set = set(rev_dict.keys()) + ref_info_dict = exp_refs_by_baseline(repo.scm, rev_set, git_remote) + for _, ref_info_list in ref_info_dict.items(): + exp_ref_set.update(ref_info_list) - exp_ref_set = exp_ref_dict.values() - _pull(repo, git_remote, exp_ref_set, force, pull_cache, **kwargs) + _pull(repo, git_remote, exp_ref_set, force) + if pull_cache: + _pull_cache(repo, exp_ref_set, **kwargs) + return [ref.name for ref in exp_ref_set] def _pull( repo, git_remote: str, - exp_refs, + refs, force: bool, - pull_cache: bool, - **kwargs, ): def on_diverged(refname: str, rev: str) -> bool: if repo.scm.get_ref(refname) == rev: @@ -56,7 +74,7 @@ def on_diverged(refname: str, rev: str) -> bool: "re-run with '--force'." 
) - refspec_list = [f"{exp_ref}:{exp_ref}" for exp_ref in exp_refs] + refspec_list = [f"{exp_ref}:{exp_ref}" for exp_ref in refs] logger.debug(f"git pull experiment '{git_remote}' -> '{refspec_list}'") with TqdmGit(desc="Fetching git refs") as pbar: @@ -68,20 +86,19 @@ def on_diverged(refname: str, rev: str) -> bool: progress=pbar.update_git, ) - if pull_cache: - _pull_cache(repo, exp_refs, **kwargs) - def _pull_cache( repo, - exp_ref, + refs: Union[ExpRefInfo, Iterable["ExpRefInfo"]], dvc_remote=None, jobs=None, run_cache=False, odb=None, ): - revs = list(exp_commits(repo.scm, exp_ref)) - logger.debug(f"dvc fetch experiment '{exp_ref}'") + if isinstance(refs, ExpRefInfo): + refs = [refs] + revs = list(exp_commits(repo.scm, refs)) + logger.debug(f"dvc fetch experiment '{refs}'") repo.fetch( jobs=jobs, remote=dvc_remote, run_cache=run_cache, revs=revs, odb=odb ) diff --git a/dvc/repo/experiments/push.py b/dvc/repo/experiments/push.py index 6666c8c840..308822ecbb 100644 --- a/dvc/repo/experiments/push.py +++ b/dvc/repo/experiments/push.py @@ -1,13 +1,20 @@ import logging -from typing import Iterable, Union +from typing import Iterable, Optional, Set, Union from dvc.exceptions import DvcException from dvc.repo import locked from dvc.repo.scm_context import scm_context -from dvc.scm import TqdmGit +from dvc.scm import TqdmGit, iter_revs +from .base import ExpRefInfo from .exceptions import UnresolvedExpNamesError -from .utils import exp_commits, push_refspec, resolve_name +from .utils import ( + exp_commits, + exp_refs, + exp_refs_by_baseline, + push_refspec, + resolve_name, +) logger = logging.getLogger(__name__) @@ -18,34 +25,52 @@ def push( repo, git_remote: str, exp_names: Union[Iterable[str], str], - *args, + all_commits=False, + rev: Optional[str] = None, + num=1, force: bool = False, push_cache: bool = False, **kwargs, -): - if isinstance(exp_names, str): - exp_names = [exp_names] +) -> Iterable[str]: + + exp_ref_set: Set["ExpRefInfo"] = set() + if all_commits: 
+ exp_ref_set.update(exp_refs(repo.scm)) + + else: + if exp_names: + if isinstance(exp_names, str): + exp_names = [exp_names] + exp_ref_dict = resolve_name(repo.scm, exp_names) - exp_ref_dict = resolve_name(repo.scm, exp_names) - unresolved_exp_names = [ - exp_name - for exp_name, exp_ref in exp_ref_dict.items() - if exp_ref is None - ] - if unresolved_exp_names: - raise UnresolvedExpNamesError(unresolved_exp_names) + unresolved_exp_names = [] + for exp_name, exp_ref in exp_ref_dict.items(): + if exp_ref is None: + unresolved_exp_names.append(exp_name) + else: + exp_ref_set.add(exp_ref) - exp_ref_set = exp_ref_dict.values() - _push(repo, git_remote, exp_ref_set, force, push_cache, **kwargs) + if unresolved_exp_names: + raise UnresolvedExpNamesError(unresolved_exp_names) + + if rev: + rev_dict = iter_revs(repo.scm, [rev], num) + rev_set = set(rev_dict.keys()) + ref_info_dict = exp_refs_by_baseline(repo.scm, rev_set) + for _, ref_info_list in ref_info_dict.items(): + exp_ref_set.update(ref_info_list) + + _push(repo, git_remote, exp_ref_set, force) + if push_cache: + _push_cache(repo, exp_ref_set, **kwargs) + return [ref.name for ref in exp_ref_set] def _push( repo, git_remote: str, - exp_refs, + refs: Iterable["ExpRefInfo"], force: bool, - push_cache: bool, - **kwargs, ): def on_diverged(refname: str, rev: str) -> bool: if repo.scm.get_ref(refname) == rev: @@ -57,9 +82,9 @@ def on_diverged(refname: str, rev: str) -> bool: "re-run with '--force'." 
) - logger.debug(f"git push experiment '{exp_refs}' -> '{git_remote}'") + logger.debug(f"git push experiment '{refs}' -> '{git_remote}'") - for exp_ref in exp_refs: + for exp_ref in refs: with TqdmGit(desc="Pushing git refs") as pbar: push_refspec( repo.scm, @@ -71,11 +96,16 @@ def on_diverged(refname: str, rev: str) -> bool: progress=pbar.update_git, ) - if push_cache: - _push_cache(repo, exp_refs, **kwargs) - -def _push_cache(repo, exp_refs, dvc_remote=None, jobs=None, run_cache=False): - revs = list(exp_commits(repo.scm, exp_refs)) - logger.debug(f"dvc push experiment '{exp_refs}'") +def _push_cache( + repo, + refs: Union[ExpRefInfo, Iterable["ExpRefInfo"]], + dvc_remote=None, + jobs=None, + run_cache=False, +): + if isinstance(refs, ExpRefInfo): + refs = [refs] + revs = list(exp_commits(repo.scm, refs)) + logger.debug(f"dvc push experiment '{refs}'") repo.push(jobs=jobs, remote=dvc_remote, run_cache=run_cache, revs=revs) diff --git a/tests/func/experiments/test_remote.py b/tests/func/experiments/test_remote.py index 6d5b9a25f6..49acc88680 100644 --- a/tests/func/experiments/test_remote.py +++ b/tests/func/experiments/test_remote.py @@ -39,6 +39,38 @@ def test_push(tmp_dir, scm, dvc, git_upstream, exp_stage, use_url): assert git_upstream.tmp_dir.scm.get_ref(str(ref_info1)) == exp1 +@pytest.mark.parametrize( + "all_,rev,result3", [(True, False, True), (False, True, None)] +) +def test_push_args( + tmp_dir, scm, dvc, git_upstream, exp_stage, all_, rev, result3 +): + remote = git_upstream.url + baseline = scm.get_rev() + + results = dvc.experiments.run(exp_stage.addressing, params=["foo=1"]) + exp1 = first(results) + ref_info1 = first(exp_refs_by_rev(scm, exp1)) + results = dvc.experiments.run(exp_stage.addressing, params=["foo=2"]) + exp2 = first(results) + ref_info2 = first(exp_refs_by_rev(scm, exp2)) + + scm.commit("new_baseline") + + results = dvc.experiments.run(exp_stage.addressing, params=["foo=3"]) + exp3 = first(results) + ref_info3 = 
first(exp_refs_by_rev(scm, exp3)) + + if rev: + rev = baseline + dvc.experiments.push(remote, [], all_commits=all_, rev=rev) + assert git_upstream.tmp_dir.scm.get_ref(str(ref_info1)) == exp1 + assert git_upstream.tmp_dir.scm.get_ref(str(ref_info2)) == exp2 + if result3: + result3 = exp3 + assert git_upstream.tmp_dir.scm.get_ref(str(ref_info3)) == result3 + + def test_push_diverged(tmp_dir, scm, dvc, git_upstream, exp_stage): git_upstream.tmp_dir.scm_gen("foo", "foo", commit="init") remote_rev = git_upstream.tmp_dir.scm.get_rev() @@ -177,6 +209,40 @@ def test_pull(tmp_dir, scm, dvc, git_downstream, exp_stage, use_url): assert git_downstream.tmp_dir.scm.get_ref(str(ref_info1)) == exp1 +@pytest.mark.parametrize( + "all_,rev,result3", [(True, False, True), (False, True, None)] +) +def test_pull_args( + tmp_dir, scm, dvc, git_downstream, exp_stage, all_, rev, result3 +): + baseline = scm.get_rev() + + results = dvc.experiments.run(exp_stage.addressing, params=["foo=1"]) + exp1 = first(results) + ref_info1 = first(exp_refs_by_rev(scm, exp1)) + results = dvc.experiments.run(exp_stage.addressing, params=["foo=2"]) + exp2 = first(results) + ref_info2 = first(exp_refs_by_rev(scm, exp2)) + + scm.commit("new_baseline") + + results = dvc.experiments.run(exp_stage.addressing, params=["foo=3"]) + exp3 = first(results) + ref_info3 = first(exp_refs_by_rev(scm, exp3)) + + if rev: + rev = baseline + + downstream_exp = git_downstream.tmp_dir.dvc.experiments + git_downstream.tmp_dir.scm.fetch_refspecs(str(tmp_dir), ["master:master"]) + downstream_exp.pull(git_downstream.remote, [], all_commits=all_, rev=rev) + assert git_downstream.tmp_dir.scm.get_ref(str(ref_info1)) == exp1 + assert git_downstream.tmp_dir.scm.get_ref(str(ref_info2)) == exp2 + if result3: + result3 = exp3 + assert git_downstream.tmp_dir.scm.get_ref(str(ref_info3)) == result3 + + def test_pull_diverged(tmp_dir, scm, dvc, git_downstream, exp_stage): git_downstream.tmp_dir.scm_gen("foo", "foo", commit="init") remote_rev = 
git_downstream.tmp_dir.scm.get_rev() diff --git a/tests/unit/command/test_experiments.py b/tests/unit/command/test_experiments.py index 35fae979c7..7cc5c1fe4c 100644 --- a/tests/unit/command/test_experiments.py +++ b/tests/unit/command/test_experiments.py @@ -229,6 +229,11 @@ def test_experiments_push(dvc, scm, mocker): "origin", "experiment1", "experiment2", + "--all-commits", + "-n", + "2", + "--rev", + "foo", "--force", "--no-cache", "--remote", @@ -249,6 +254,9 @@ def test_experiments_push(dvc, scm, mocker): cmd.repo, "origin", ["experiment1", "experiment2"], + rev="foo", + all_commits=True, + num=2, force=True, push_cache=False, dvc_remote="my-remote", @@ -256,6 +264,22 @@ def test_experiments_push(dvc, scm, mocker): run_cache=True, ) + cli_args = parse_args( + [ + "experiments", + "push", + "origin", + ] + ) + cmd = cli_args.func(cli_args) + + with pytest.raises(InvalidArgumentError) as exp_info: + cmd.run() + assert ( + str(exp_info.value) == "Either provide an `experiment` argument" + ", or use the `--rev` or `--all-commits` flag." + ) + def test_experiments_pull(dvc, scm, mocker): cli_args = parse_args( @@ -264,6 +288,9 @@ def test_experiments_pull(dvc, scm, mocker): "pull", "origin", "experiment", + "--all-commits", + "--rev", + "foo", "--force", "--no-cache", "--remote", @@ -284,6 +311,9 @@ def test_experiments_pull(dvc, scm, mocker): cmd.repo, "origin", ["experiment"], + rev="foo", + all_commits=True, + num=1, force=True, pull_cache=False, dvc_remote="my-remote", @@ -291,6 +321,22 @@ def test_experiments_pull(dvc, scm, mocker): run_cache=True, ) + cli_args = parse_args( + [ + "experiments", + "pull", + "origin", + ] + ) + cmd = cli_args.func(cli_args) + + with pytest.raises(InvalidArgumentError) as exp_info: + cmd.run() + assert ( + str(exp_info.value) == "Either provide an `experiment` argument" + ", or use the `--rev` or `--all-commits` flag." + ) + @pytest.mark.parametrize( "queue,clear_all,remote",