diff --git a/dvc/command/experiments/pull.py b/dvc/command/experiments/pull.py index b84b10083d..45af21a60c 100644 --- a/dvc/command/experiments/pull.py +++ b/dvc/command/experiments/pull.py @@ -87,6 +87,9 @@ def add_parser(experiments_subparsers, parent_parser): metavar="", ) experiments_pull_parser.add_argument( - "experiment", help="Experiment to pull.", metavar="" + "experiment", + nargs="+", + help="Experiments to pull.", + metavar="", ) experiments_pull_parser.set_defaults(func=CmdExperimentsPull) diff --git a/dvc/command/experiments/push.py b/dvc/command/experiments/push.py index b596b5c0b8..73b6becd94 100644 --- a/dvc/command/experiments/push.py +++ b/dvc/command/experiments/push.py @@ -88,6 +88,9 @@ def add_parser(experiments_subparsers, parent_parser): metavar="", ) experiments_push_parser.add_argument( - "experiment", help="Experiment to push.", metavar="" + "experiment", + nargs="+", + help="Experiments to push.", + metavar="", ).complete = completion.EXPERIMENT experiments_push_parser.set_defaults(func=CmdExperimentsPush) diff --git a/dvc/repo/experiments/exceptions.py b/dvc/repo/experiments/exceptions.py new file mode 100644 index 0000000000..d6a7a304e1 --- /dev/null +++ b/dvc/repo/experiments/exceptions.py @@ -0,0 +1,43 @@ +from typing import Iterable, List + +from dvc.exceptions import InvalidArgumentError + +from .base import ExpRefInfo + + +class AmbiguousExpRefInfo(InvalidArgumentError): + def __init__( + self, + exp_name: str, + exp_ref_list: Iterable[ExpRefInfo], + ): + msg = [ + ( + f"Ambiguous name '{exp_name}' refers to multiple experiments." + " Use one of the following full refnames instead:" + ), + "", + ] + msg.extend([f"\t{info}" for info in exp_ref_list]) + super().__init__("\n".join(msg)) + + +class UnresolvedExpNamesError(InvalidArgumentError): + def __init__( + self, unresolved_list: List[str], *args, git_remote: str = None + ): + unresolved_names = ";".join(unresolved_list) + if not git_remote: + if len(unresolved_names) > 1: + super().__init__( + f"'{unresolved_names}' are not valid experiment names" + ) + else: + super().__init__( + f"'{unresolved_names}' is not a valid experiment name" + ) + else: + super().__init__( + f"Experiment '{unresolved_names}' does not exist " + f"in '{git_remote}'" + ) diff --git a/dvc/repo/experiments/pull.py b/dvc/repo/experiments/pull.py index 8760e523e4..78b0442db4 100644 --- a/dvc/repo/experiments/pull.py +++ b/dvc/repo/experiments/pull.py @@ -1,9 +1,12 @@ import logging +from typing import Iterable, Union -from dvc.exceptions import DvcException, InvalidArgumentError +from dvc.exceptions import DvcException from dvc.repo import locked from dvc.repo.scm_context import scm_context +from dvc.scm import TqdmGit +from .exceptions import UnresolvedExpNamesError from .utils import exp_commits, resolve_name logger = logging.getLogger(__name__) @@ -12,40 +15,61 @@ @locked @scm_context def pull( - repo, git_remote, exp_name, *args, force=False, pull_cache=False, **kwargs + repo, + git_remote: str, + exp_names: Union[Iterable[str], str], + *args, + force: bool = False, + pull_cache: bool = False, + **kwargs, ): - exp_ref_dict = resolve_name(repo.scm, exp_name, git_remote) - exp_ref = exp_ref_dict[exp_name] - if not exp_ref: - raise InvalidArgumentError( - f"Experiment '{exp_name}' does not exist in '{git_remote}'" - ) + if isinstance(exp_names, str): + exp_names = [exp_names] + exp_ref_dict = resolve_name(repo.scm, exp_names, git_remote) + unresolved_exp_names = [ + exp_name + for exp_name, exp_ref in exp_ref_dict.items() + if exp_ref is None + ] + if unresolved_exp_names: + raise UnresolvedExpNamesError(unresolved_exp_names) + + exp_ref_set = exp_ref_dict.values() + _pull(repo, git_remote, exp_ref_set, force, pull_cache, **kwargs) + +def _pull( + repo, + git_remote: str, + exp_refs, + force: bool, + pull_cache: bool, + **kwargs, +): def on_diverged(refname: str, rev: str) -> bool: if repo.scm.get_ref(refname) == rev: return True + exp_name = refname.split("/")[-1] raise DvcException( f"Local experiment '{exp_name}' has diverged from remote " "experiment with the same name. To override the local experiment " "re-run with '--force'." ) - refspec = f"{exp_ref}:{exp_ref}" - logger.debug("git pull experiment '%s' -> '%s'", git_remote, refspec) - - from dvc.scm import TqdmGit + refspec_list = [f"{exp_ref}:{exp_ref}" for exp_ref in exp_refs] + logger.debug(f"git pull experiment '{git_remote}' -> '{refspec_list}'") with TqdmGit(desc="Fetching git refs") as pbar: repo.scm.fetch_refspecs( git_remote, - [refspec], + refspec_list, force=force, on_diverged=on_diverged, progress=pbar.update_git, ) if pull_cache: - _pull_cache(repo, exp_ref, **kwargs) + _pull_cache(repo, exp_refs, **kwargs) def _pull_cache( @@ -56,8 +80,8 @@ def _pull_cache( run_cache=False, odb=None, ): - revs = list(exp_commits(repo.scm, [exp_ref])) - logger.debug("dvc fetch experiment '%s'", exp_ref) + revs = list(exp_commits(repo.scm, exp_ref)) + logger.debug(f"dvc fetch experiment '{exp_ref}'") repo.fetch( jobs=jobs, remote=dvc_remote, run_cache=run_cache, revs=revs, odb=odb ) diff --git a/dvc/repo/experiments/push.py b/dvc/repo/experiments/push.py index b0b43e4828..6666c8c840 100644 --- a/dvc/repo/experiments/push.py +++ b/dvc/repo/experiments/push.py @@ -1,9 +1,12 @@ import logging +from typing import Iterable, Union -from dvc.exceptions import DvcException, InvalidArgumentError +from dvc.exceptions import DvcException from dvc.repo import locked from dvc.repo.scm_context import scm_context +from dvc.scm import TqdmGit +from .exceptions import UnresolvedExpNamesError from .utils import exp_commits, push_refspec, resolve_name logger = logging.getLogger(__name__) @@ -13,50 +16,66 @@ @scm_context def push( repo, - git_remote, - exp_name: str, + git_remote: str, + exp_names: Union[Iterable[str], str], *args, - force=False, - push_cache=False, + force: bool = False, + push_cache: bool = False, **kwargs, ): - exp_ref_dict = resolve_name(repo.scm, exp_name) - exp_ref = exp_ref_dict[exp_name] - if not exp_ref: - raise InvalidArgumentError( - f"'{exp_name}' is not a valid experiment name" - ) + if isinstance(exp_names, str): + exp_names = [exp_names] + + exp_ref_dict = resolve_name(repo.scm, exp_names) + unresolved_exp_names = [ + exp_name + for exp_name, exp_ref in exp_ref_dict.items() + if exp_ref is None + ] + if unresolved_exp_names: + raise UnresolvedExpNamesError(unresolved_exp_names) + + exp_ref_set = exp_ref_dict.values() + _push(repo, git_remote, exp_ref_set, force, push_cache, **kwargs) + +def _push( + repo, + git_remote: str, + exp_refs, + force: bool, + push_cache: bool, + **kwargs, +): def on_diverged(refname: str, rev: str) -> bool: if repo.scm.get_ref(refname) == rev: return True + exp_name = refname.split("/")[-1] raise DvcException( f"Local experiment '{exp_name}' has diverged from remote " "experiment with the same name. To override the remote experiment " "re-run with '--force'." ) - refname = str(exp_ref) - logger.debug("git push experiment '%s' -> '%s'", exp_ref, git_remote) + logger.debug(f"git push experiment '{exp_refs}' -> '{git_remote}'") - from dvc.scm import TqdmGit - - with TqdmGit(desc="Pushing git refs") as pbar: - push_refspec( - repo.scm, - git_remote, - refname, - refname, - force=force, - on_diverged=on_diverged, - progress=pbar.update_git, - ) + for exp_ref in exp_refs: + with TqdmGit(desc="Pushing git refs") as pbar: + push_refspec( + repo.scm, + git_remote, + str(exp_ref), + str(exp_ref), + force=force, + on_diverged=on_diverged, + progress=pbar.update_git, + ) if push_cache: - _push_cache(repo, exp_ref, **kwargs) + _push_cache(repo, exp_refs, **kwargs) -def _push_cache(repo, exp_ref, dvc_remote=None, jobs=None, run_cache=False): - revs = list(exp_commits(repo.scm, [exp_ref])) - logger.debug("dvc push experiment '%s'", exp_ref) +def _push_cache(repo, exp_refs, dvc_remote=None, jobs=None, run_cache=False): + revs = list(exp_commits(repo.scm, exp_refs)) + logger.debug(f"dvc push experiment '{exp_refs}'") repo.push(jobs=jobs, remote=dvc_remote, run_cache=run_cache, revs=revs) diff --git a/dvc/repo/experiments/utils.py b/dvc/repo/experiments/utils.py index c645c32242..9abf301c1f 100644 --- a/dvc/repo/experiments/utils.py +++ b/dvc/repo/experiments/utils.py @@ -13,6 +13,7 @@ from scmrepo.git import Git from dvc.exceptions import InvalidArgumentError +from dvc.repo.experiments.exceptions import AmbiguousExpRefInfo from .base import ( EXEC_BASELINE, @@ -23,23 +24,6 @@ ) -class AmbiguousExpRefInfo(InvalidArgumentError): - def __init__( - self, - exp_name: str, - exp_ref_list: Iterable[ExpRefInfo], - ): - msg = [ - ( - f"Ambiguous name '{exp_name}' refers to multiple experiments." - " Use one of the following full refnames instead:" - ), - "", - ] - msg.extend([f"\t{info}" for info in exp_ref_list]) - super().__init__("\n".join(msg)) - - def exp_refs(scm: "Git") -> Generator["ExpRefInfo", None, None]: """Iterate over all experiment refs.""" for ref in scm.iter_refs(base=EXPS_NAMESPACE): diff --git a/tests/func/experiments/test_remote.py b/tests/func/experiments/test_remote.py index 47bce63e8a..cc697eb968 100644 --- a/tests/func/experiments/test_remote.py +++ b/tests/func/experiments/test_remote.py @@ -13,19 +13,30 @@ def test_push(tmp_dir, scm, dvc, git_upstream, exp_stage, use_url): remote = git_upstream.url if use_url else git_upstream.remote with pytest.raises(InvalidArgumentError): - dvc.experiments.push(remote, "foo") + dvc.experiments.push(remote, ["foo"]) + + results = dvc.experiments.run(exp_stage.addressing, params=["foo=1"]) + exp1 = first(results) + ref_info1 = first(exp_refs_by_rev(scm, exp1)) results = dvc.experiments.run(exp_stage.addressing, params=["foo=2"]) - exp = first(results) - ref_info = first(exp_refs_by_rev(scm, exp)) + exp2 = first(results) + ref_info2 = first(exp_refs_by_rev(scm, exp2)) - dvc.experiments.push(remote, ref_info.name) - assert git_upstream.tmp_dir.scm.get_ref(str(ref_info)) == exp + results = dvc.experiments.run(exp_stage.addressing, params=["foo=3"]) + exp3 = first(results) + ref_info3 = first(exp_refs_by_rev(scm, exp3)) - git_upstream.tmp_dir.scm.remove_ref(str(ref_info)) + dvc.experiments.push(remote, [ref_info1.name, ref_info2.name]) + assert git_upstream.tmp_dir.scm.get_ref(str(ref_info1)) == exp1 + assert git_upstream.tmp_dir.scm.get_ref(str(ref_info2)) == exp2 + assert git_upstream.tmp_dir.scm.get_ref(str(ref_info3)) is None - dvc.experiments.push(remote, str(ref_info)) - assert git_upstream.tmp_dir.scm.get_ref(str(ref_info)) == exp + git_upstream.tmp_dir.scm.remove_ref(str(ref_info1)) + assert git_upstream.tmp_dir.scm.get_ref(str(ref_info1)) is None + + dvc.experiments.push(remote, [ref_info1.name]) + assert git_upstream.tmp_dir.scm.get_ref(str(ref_info1)) == exp1 def test_push_diverged(tmp_dir, scm, dvc, git_upstream, exp_stage): @@ -39,10 +50,10 @@ def test_push_diverged(tmp_dir, scm, dvc, git_upstream, exp_stage): git_upstream.tmp_dir.scm.set_ref(str(ref_info), remote_rev) with pytest.raises(DvcException): - dvc.experiments.push(git_upstream.remote, ref_info.name) + dvc.experiments.push(git_upstream.remote, [ref_info.name]) assert git_upstream.tmp_dir.scm.get_ref(str(ref_info)) == remote_rev - dvc.experiments.push(git_upstream.remote, ref_info.name, force=True) + dvc.experiments.push(git_upstream.remote, [ref_info.name], force=True) assert git_upstream.tmp_dir.scm.get_ref(str(ref_info)) == exp @@ -53,7 +64,7 @@ def test_push_checkpoint(tmp_dir, scm, dvc, git_upstream, checkpoint_stage): exp_a = first(results) ref_info_a = first(exp_refs_by_rev(scm, exp_a)) - dvc.experiments.push(git_upstream.remote, ref_info_a.name, force=True) + dvc.experiments.push(git_upstream.remote, [ref_info_a.name], force=True) assert git_upstream.tmp_dir.scm.get_ref(str(ref_info_a)) == exp_a results = dvc.experiments.run( @@ -64,7 +75,7 @@ def test_push_checkpoint(tmp_dir, scm, dvc, git_upstream, checkpoint_stage): tmp_dir.scm_gen("new", "new", commit="new") - dvc.experiments.push(git_upstream.remote, ref_info_b.name, force=True) + dvc.experiments.push(git_upstream.remote, [ref_info_b.name], force=True) assert git_upstream.tmp_dir.scm.get_ref(str(ref_info_b)) == exp_b @@ -86,15 +97,15 @@ def test_push_ambiguous_name(tmp_dir, scm, dvc, git_upstream, exp_stage): exp_b = first(results) ref_info_b = first(exp_refs_by_rev(scm, exp_b)) - dvc.experiments.push(remote, "foo") + dvc.experiments.push(remote, ["foo"]) assert git_upstream.tmp_dir.scm.get_ref(str(ref_info_b)) == exp_b tmp_dir.scm_gen("new", "new 2", commit="new 2") with pytest.raises(InvalidArgumentError): - dvc.experiments.push(remote, "foo") + dvc.experiments.push(remote, ["foo"]) - dvc.experiments.push(remote, str(ref_info_a)) + dvc.experiments.push(remote, [str(ref_info_a)]) assert git_upstream.tmp_dir.scm.get_ref(str(ref_info_a)) == exp_a @@ -140,19 +151,29 @@ def test_pull(tmp_dir, scm, dvc, git_downstream, exp_stage, use_url): remote = git_downstream.url if use_url else git_downstream.remote downstream_exp = git_downstream.tmp_dir.dvc.experiments with pytest.raises(InvalidArgumentError): - downstream_exp.pull(remote, "foo") + downstream_exp.pull(remote, ["foo"]) + results = dvc.experiments.run(exp_stage.addressing, params=["foo=1"]) + exp1 = first(results) + ref_info1 = first(exp_refs_by_rev(scm, exp1)) results = dvc.experiments.run(exp_stage.addressing, params=["foo=2"]) - exp = first(results) - ref_info = first(exp_refs_by_rev(scm, exp)) + exp2 = first(results) + ref_info2 = first(exp_refs_by_rev(scm, exp2)) + results = dvc.experiments.run(exp_stage.addressing, params=["foo=3"]) + exp3 = first(results) + ref_info3 = first(exp_refs_by_rev(scm, exp3)) - downstream_exp.pull(remote, ref_info.name) - assert git_downstream.tmp_dir.scm.get_ref(str(ref_info)) == exp + downstream_exp.pull( + git_downstream.remote, [ref_info1.name, ref_info2.name], force=True + ) + assert git_downstream.tmp_dir.scm.get_ref(str(ref_info1)) == exp1 + assert git_downstream.tmp_dir.scm.get_ref(str(ref_info2)) == exp2 + assert git_downstream.tmp_dir.scm.get_ref(str(ref_info3)) is None - git_downstream.tmp_dir.scm.remove_ref(str(ref_info)) + git_downstream.tmp_dir.scm.remove_ref(str(ref_info1)) - downstream_exp.pull(remote, str(ref_info)) - assert git_downstream.tmp_dir.scm.get_ref(str(ref_info)) == exp + downstream_exp.pull(remote, [str(ref_info1)]) + assert git_downstream.tmp_dir.scm.get_ref(str(ref_info1)) == exp1 def test_pull_diverged(tmp_dir, scm, dvc, git_downstream, exp_stage): @@ -182,7 +203,7 @@ def test_pull_checkpoint(tmp_dir, scm, dvc, git_downstream, checkpoint_stage): ref_info_a = first(exp_refs_by_rev(scm, exp_a)) downstream_exp = git_downstream.tmp_dir.dvc.experiments - downstream_exp.pull(git_downstream.remote, ref_info_a.name, force=True) + downstream_exp.pull(git_downstream.remote, [ref_info_a.name], force=True) assert git_downstream.tmp_dir.scm.get_ref(str(ref_info_a)) == exp_a results = dvc.experiments.run( @@ -191,7 +212,7 @@ def test_pull_checkpoint(tmp_dir, scm, dvc, git_downstream, checkpoint_stage): exp_b = first(results) ref_info_b = first(exp_refs_by_rev(scm, exp_b)) - downstream_exp.pull(git_downstream.remote, ref_info_b.name, force=True) + downstream_exp.pull(git_downstream.remote, [ref_info_b.name], force=True) assert git_downstream.tmp_dir.scm.get_ref(str(ref_info_b)) == exp_b @@ -214,13 +235,13 @@ def test_pull_ambiguous_name(tmp_dir, scm, dvc, git_downstream, exp_stage): remote = git_downstream.remote downstream_exp = git_downstream.tmp_dir.dvc.experiments with pytest.raises(InvalidArgumentError): - downstream_exp.pull(remote, "foo") + downstream_exp.pull(remote, ["foo"]) - downstream_exp.pull(remote, str(ref_info_b)) + downstream_exp.pull(remote, [str(ref_info_b)]) assert git_downstream.tmp_dir.scm.get_ref(str(ref_info_b)) == exp_b with git_downstream.tmp_dir.scm.detach_head(ref_info_a.baseline_sha): - downstream_exp.pull(remote, "foo") + downstream_exp.pull(remote, ["foo"]) assert git_downstream.tmp_dir.scm.get_ref(str(ref_info_a)) == exp_a @@ -237,7 +258,7 @@ def test_push_pull_cache( exp = first(results) ref_info = first(exp_refs_by_rev(scm, exp)) - dvc.experiments.push(remote, ref_info.name, push_cache=True) + dvc.experiments.push(remote, [ref_info.name], push_cache=True) for x in range(2, checkpoint_stage.iterations + 1): hash_ = digest(str(x)) path = os.path.join(local_remote.url, hash_[:2], hash_[2:]) @@ -246,7 +267,7 @@ def test_push_pull_cache( remove(dvc.odb.local.cache_dir) - dvc.experiments.pull(remote, ref_info.name, pull_cache=True) + dvc.experiments.pull(remote, [ref_info.name], pull_cache=True) for x in range(2, checkpoint_stage.iterations + 1): hash_ = digest(str(x)) path = os.path.join(dvc.odb.local.cache_dir, hash_[:2], hash_[2:]) @@ -271,7 +292,7 @@ def test_auth_error_pull(tmp_dir, scm, dvc, http_auth_patch): GitAuthError, match=f"HTTP Git authentication is not supported: '{http_auth_patch}'", ): - dvc.experiments.pull(http_auth_patch, "foo") + dvc.experiments.pull(http_auth_patch, ["foo"]) def test_auth_error_push(tmp_dir, scm, dvc, exp_stage, http_auth_patch): @@ -285,4 +306,4 @@ def test_auth_error_push(tmp_dir, scm, dvc, exp_stage, http_auth_patch): GitAuthError, match=f"HTTP Git authentication is not supported: '{http_auth_patch}'", ): - dvc.experiments.push(http_auth_patch, ref_info.name) + dvc.experiments.push(http_auth_patch, [ref_info.name]) diff --git a/tests/func/experiments/test_remove.py b/tests/func/experiments/test_remove.py index 30fd7d9e90..d51cbcca44 100644 --- a/tests/func/experiments/test_remove.py +++ b/tests/func/experiments/test_remove.py @@ -106,7 +106,7 @@ def test_remove_remote(tmp_dir, scm, dvc, exp_stage, git_upstream, use_url): exp_list.append(exp) ref_info = first(exp_refs_by_rev(scm, exp)) ref_info_list.append(ref_info) - dvc.experiments.push(remote, ref_info.name) + dvc.experiments.push(remote, [ref_info.name]) assert git_upstream.tmp_dir.scm.get_ref(str(ref_info)) == exp dvc.experiments.remove( diff --git a/tests/unit/command/test_experiments.py b/tests/unit/command/test_experiments.py index dbc21c50c5..2ff01914e8 100644 --- a/tests/unit/command/test_experiments.py +++ b/tests/unit/command/test_experiments.py @@ -221,7 +221,8 @@ def test_experiments_push(dvc, scm, mocker): "experiments", "push", "origin", - "experiment", + "experiment1", + "experiment2", "--force", "--no-cache", "--remote", @@ -241,7 +242,7 @@ def test_experiments_push(dvc, scm, mocker): m.assert_called_once_with( cmd.repo, "origin", - "experiment", + ["experiment1", "experiment2"], force=True, push_cache=False, dvc_remote="my-remote", @@ -276,7 +277,7 @@ def test_experiments_pull(dvc, scm, mocker): m.assert_called_once_with( cmd.repo, "origin", - "experiment", + ["experiment"], force=True, pull_cache=False, dvc_remote="my-remote",