diff --git a/dvc/commands/experiments/pull.py b/dvc/commands/experiments/pull.py index 98acc40ec9..3fab2d75ba 100644 --- a/dvc/commands/experiments/pull.py +++ b/dvc/commands/experiments/pull.py @@ -3,23 +3,13 @@ from dvc.cli.command import CmdBase from dvc.cli.utils import append_doc_link -from dvc.exceptions import InvalidArgumentError from dvc.ui import ui logger = logging.getLogger(__name__) class CmdExperimentsPull(CmdBase): - def raise_error_if_all_disabled(self): - if not any([self.args.experiment, self.args.all_commits, self.args.rev]): - raise InvalidArgumentError( - "Either provide an `experiment` argument, or use the " - "`--rev` or `--all-commits` flag." - ) - def run(self): - self.raise_error_if_all_disabled() - pulled_exps = self.repo.experiments.pull( self.args.git_remote, self.args.experiment, diff --git a/dvc/commands/experiments/push.py b/dvc/commands/experiments/push.py index fdd3834cad..2982710468 100644 --- a/dvc/commands/experiments/push.py +++ b/dvc/commands/experiments/push.py @@ -5,20 +5,12 @@ from dvc.cli import completion from dvc.cli.command import CmdBase from dvc.cli.utils import append_doc_link -from dvc.exceptions import InvalidArgumentError from dvc.ui import ui logger = logging.getLogger(__name__) class CmdExperimentsPush(CmdBase): - def raise_error_if_all_disabled(self): - if not any([self.args.experiment, self.args.all_commits, self.args.rev]): - raise InvalidArgumentError( - "Either provide an `experiment` argument, or use the " - "`--rev` or `--all-commits` flag." - ) - @staticmethod def log_result(result: Dict[str, Any], remote: str): from dvc.utils import humanize @@ -59,8 +51,6 @@ def join_exps(exps): def run(self): from dvc.repo.experiments.push import UploadError - self.raise_error_if_all_disabled() - try: result = self.repo.experiments.push( self.args.git_remote, diff --git a/dvc/repo/experiments/pull.py b/dvc/repo/experiments/pull.py index 2d4624c4f8..6841cfd4c3 100644 --- a/dvc/repo/experiments/pull.py +++ b/dvc/repo/experiments/pull.py @@ -21,7 +21,7 @@ def pull( # noqa: C901 repo, git_remote: str, - exp_names: Union[Iterable[str], str], + exp_names: Optional[Union[Iterable[str], str]] = None, all_commits=False, rev: Optional[Union[List[str], str]] = None, num=1, @@ -32,30 +32,30 @@ def pull( # noqa: C901 exp_ref_set: Set["ExpRefInfo"] = set() if all_commits: exp_ref_set.update(exp_refs(repo.scm, git_remote)) + elif exp_names: + if isinstance(exp_names, str): + exp_names = [exp_names] + exp_ref_dict = resolve_name(repo.scm, exp_names, git_remote) + + unresolved_exp_names = [] + for exp_name, exp_ref in exp_ref_dict.items(): + if exp_ref is None: + unresolved_exp_names.append(exp_name) + else: + exp_ref_set.add(exp_ref) + + if unresolved_exp_names: + raise UnresolvedExpNamesError(unresolved_exp_names) + else: - if exp_names: - if isinstance(exp_names, str): - exp_names = [exp_names] - exp_ref_dict = resolve_name(repo.scm, exp_names, git_remote) - - unresolved_exp_names = [] - for exp_name, exp_ref in exp_ref_dict.items(): - if exp_ref is None: - unresolved_exp_names.append(exp_name) - else: - exp_ref_set.add(exp_ref) - - if unresolved_exp_names: - raise UnresolvedExpNamesError(unresolved_exp_names) - - if rev: - if isinstance(rev, str): - rev = [rev] - rev_dict = iter_revs(repo.scm, rev, num) - rev_set = set(rev_dict.keys()) - ref_info_dict = exp_refs_by_baseline(repo.scm, rev_set, git_remote) - for _, ref_info_list in ref_info_dict.items(): - exp_ref_set.update(ref_info_list) + rev = rev or "HEAD" + if isinstance(rev, str): + rev = [rev] + rev_dict = iter_revs(repo.scm, rev, num) + rev_set = set(rev_dict.keys()) + ref_info_dict = exp_refs_by_baseline(repo.scm, rev_set, git_remote) + for _, ref_info_list in ref_info_dict.items(): + exp_ref_set.update(ref_info_list) pull_result = _pull(repo, git_remote, exp_ref_set, force) diff --git a/dvc/repo/experiments/push.py b/dvc/repo/experiments/push.py index 545c4e671f..f7b81baa20 100644 --- a/dvc/repo/experiments/push.py +++ b/dvc/repo/experiments/push.py @@ -97,7 +97,7 @@ def exp_refs_from_rev(scm: "Git", rev: List[str], num: int = 1) -> Set["ExpRefIn def push( repo: "Repo", git_remote: str, - exp_names: Union[List[str], str], + exp_names: Optional[Union[List[str], str]] = None, all_commits: bool = False, rev: Optional[Union[List[str], str]] = None, num: int = 1, @@ -111,7 +111,8 @@ def push( exp_ref_set.update(exp_refs(repo.scm)) if exp_names: exp_ref_set.update(exp_refs_from_names(repo.scm, ensure_list(exp_names))) - if rev: + else: + rev = rev or "HEAD" if isinstance(rev, str): rev = [rev] exp_ref_set.update(exp_refs_from_rev(repo.scm, rev, num=num)) diff --git a/tests/func/experiments/test_remote.py b/tests/func/experiments/test_remote.py index 22a901ec89..aab24f80c2 100644 --- a/tests/func/experiments/test_remote.py +++ b/tests/func/experiments/test_remote.py @@ -35,6 +35,9 @@ def test_push(tmp_dir, scm, dvc, git_upstream, exp_stage, use_url): dvc.experiments.push(remote, [ref_info1.name]) assert git_upstream.tmp_dir.scm.get_ref(str(ref_info1)) == exp1 + dvc.experiments.push(remote) + assert git_upstream.tmp_dir.scm.get_ref(str(ref_info3)) == exp3 + @pytest.mark.parametrize("all_,rev,result3", [(True, False, True), (False, True, None)]) def test_push_args(tmp_dir, scm, dvc, git_upstream, exp_stage, all_, rev, result3): @@ -173,6 +176,11 @@ def test_list_remote(tmp_dir, scm, dvc, git_downstream, exp_stage, use_url): def test_pull(tmp_dir, scm, dvc, git_downstream, exp_stage, use_url): from dvc.exceptions import InvalidArgumentError + # fetch and checkout to downstream so both repos start from same commit + downstream_repo = git_downstream.tmp_dir.scm.gitpython.repo + fetched = downstream_repo.remote(git_downstream.remote).fetch() + downstream_repo.git.checkout(fetched) + remote = git_downstream.url if use_url else git_downstream.remote downstream_exp = git_downstream.tmp_dir.dvc.experiments with pytest.raises(InvalidArgumentError): @@ -200,6 +208,9 @@ def test_pull(tmp_dir, scm, dvc, git_downstream, exp_stage, use_url): downstream_exp.pull(remote, [str(ref_info1)]) assert git_downstream.tmp_dir.scm.get_ref(str(ref_info1)) == exp1 + downstream_exp.pull(remote) + assert git_downstream.tmp_dir.scm.get_ref(str(ref_info3)) == exp3 + @pytest.mark.parametrize("all_,rev,result3", [(True, False, True), (False, True, None)]) def test_pull_args(tmp_dir, scm, dvc, git_downstream, exp_stage, all_, rev, result3): diff --git a/tests/unit/command/test_experiments.py b/tests/unit/command/test_experiments.py index 030a02fa12..38ff8001f8 100644 --- a/tests/unit/command/test_experiments.py +++ b/tests/unit/command/test_experiments.py @@ -294,12 +294,7 @@ def test_experiments_push(dvc, scm, mocker): ) cmd = cli_args.func(cli_args) - with pytest.raises(InvalidArgumentError) as exp_info: - cmd.run() - assert ( - str(exp_info.value) == "Either provide an `experiment` argument" - ", or use the `--rev` or `--all-commits` flag." - ) + assert cmd.run() == 0 def test_experiments_pull(dvc, scm, mocker): @@ -351,12 +346,7 @@ def test_experiments_pull(dvc, scm, mocker): ) cmd = cli_args.func(cli_args) - with pytest.raises(InvalidArgumentError) as exp_info: - cmd.run() - assert ( - str(exp_info.value) == "Either provide an `experiment` argument" - ", or use the `--rev` or `--all-commits` flag." - ) + assert cmd.run() == 0 def test_experiments_remove_flag(dvc, scm, mocker, capsys, caplog):