Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 0 additions & 10 deletions dvc/commands/experiments/pull.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,13 @@

from dvc.cli.command import CmdBase
from dvc.cli.utils import append_doc_link
from dvc.exceptions import InvalidArgumentError
from dvc.ui import ui

logger = logging.getLogger(__name__)


class CmdExperimentsPull(CmdBase):
def raise_error_if_all_disabled(self):
if not any([self.args.experiment, self.args.all_commits, self.args.rev]):
raise InvalidArgumentError(
"Either provide an `experiment` argument, or use the "
"`--rev` or `--all-commits` flag."
)

def run(self):
self.raise_error_if_all_disabled()

pulled_exps = self.repo.experiments.pull(
self.args.git_remote,
self.args.experiment,
Expand Down
10 changes: 0 additions & 10 deletions dvc/commands/experiments/push.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,12 @@
from dvc.cli import completion
from dvc.cli.command import CmdBase
from dvc.cli.utils import append_doc_link
from dvc.exceptions import InvalidArgumentError
from dvc.ui import ui

logger = logging.getLogger(__name__)


class CmdExperimentsPush(CmdBase):
def raise_error_if_all_disabled(self):
if not any([self.args.experiment, self.args.all_commits, self.args.rev]):
raise InvalidArgumentError(
"Either provide an `experiment` argument, or use the "
"`--rev` or `--all-commits` flag."
)

@staticmethod
def log_result(result: Dict[str, Any], remote: str):
from dvc.utils import humanize
Expand Down Expand Up @@ -59,8 +51,6 @@ def join_exps(exps):
def run(self):
from dvc.repo.experiments.push import UploadError

self.raise_error_if_all_disabled()

try:
result = self.repo.experiments.push(
self.args.git_remote,
Expand Down
48 changes: 24 additions & 24 deletions dvc/repo/experiments/pull.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
def pull( # noqa: C901
repo,
git_remote: str,
exp_names: Union[Iterable[str], str],
exp_names: Optional[Union[Iterable[str], str]] = None,
all_commits=False,
rev: Optional[Union[List[str], str]] = None,
num=1,
Expand All @@ -32,30 +32,30 @@ def pull( # noqa: C901
exp_ref_set: Set["ExpRefInfo"] = set()
if all_commits:
exp_ref_set.update(exp_refs(repo.scm, git_remote))
elif exp_names:
if isinstance(exp_names, str):
exp_names = [exp_names]
exp_ref_dict = resolve_name(repo.scm, exp_names, git_remote)

unresolved_exp_names = []
for exp_name, exp_ref in exp_ref_dict.items():
if exp_ref is None:
unresolved_exp_names.append(exp_name)
else:
exp_ref_set.add(exp_ref)

if unresolved_exp_names:
raise UnresolvedExpNamesError(unresolved_exp_names)

else:
if exp_names:
if isinstance(exp_names, str):
exp_names = [exp_names]
exp_ref_dict = resolve_name(repo.scm, exp_names, git_remote)

unresolved_exp_names = []
for exp_name, exp_ref in exp_ref_dict.items():
if exp_ref is None:
unresolved_exp_names.append(exp_name)
else:
exp_ref_set.add(exp_ref)

if unresolved_exp_names:
raise UnresolvedExpNamesError(unresolved_exp_names)

if rev:
if isinstance(rev, str):
rev = [rev]
rev_dict = iter_revs(repo.scm, rev, num)
rev_set = set(rev_dict.keys())
ref_info_dict = exp_refs_by_baseline(repo.scm, rev_set, git_remote)
for _, ref_info_list in ref_info_dict.items():
exp_ref_set.update(ref_info_list)
rev = rev or "HEAD"
if isinstance(rev, str):
rev = [rev]
rev_dict = iter_revs(repo.scm, rev, num)
rev_set = set(rev_dict.keys())
ref_info_dict = exp_refs_by_baseline(repo.scm, rev_set, git_remote)
for _, ref_info_list in ref_info_dict.items():
exp_ref_set.update(ref_info_list)

pull_result = _pull(repo, git_remote, exp_ref_set, force)

Expand Down
5 changes: 3 additions & 2 deletions dvc/repo/experiments/push.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def exp_refs_from_rev(scm: "Git", rev: List[str], num: int = 1) -> Set["ExpRefIn
def push(
repo: "Repo",
git_remote: str,
exp_names: Union[List[str], str],
exp_names: Optional[Union[List[str], str]] = None,
all_commits: bool = False,
rev: Optional[Union[List[str], str]] = None,
num: int = 1,
Expand All @@ -111,7 +111,8 @@ def push(
exp_ref_set.update(exp_refs(repo.scm))
if exp_names:
exp_ref_set.update(exp_refs_from_names(repo.scm, ensure_list(exp_names)))
if rev:
else:
rev = rev or "HEAD"
if isinstance(rev, str):
rev = [rev]
exp_ref_set.update(exp_refs_from_rev(repo.scm, rev, num=num))
Expand Down
11 changes: 11 additions & 0 deletions tests/func/experiments/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ def test_push(tmp_dir, scm, dvc, git_upstream, exp_stage, use_url):
dvc.experiments.push(remote, [ref_info1.name])
assert git_upstream.tmp_dir.scm.get_ref(str(ref_info1)) == exp1

dvc.experiments.push(remote)
assert git_upstream.tmp_dir.scm.get_ref(str(ref_info3)) == exp3


@pytest.mark.parametrize("all_,rev,result3", [(True, False, True), (False, True, None)])
def test_push_args(tmp_dir, scm, dvc, git_upstream, exp_stage, all_, rev, result3):
Expand Down Expand Up @@ -173,6 +176,11 @@ def test_list_remote(tmp_dir, scm, dvc, git_downstream, exp_stage, use_url):
def test_pull(tmp_dir, scm, dvc, git_downstream, exp_stage, use_url):
from dvc.exceptions import InvalidArgumentError

# fetch and checkout to downstream so both repos start from same commit
downstream_repo = git_downstream.tmp_dir.scm.gitpython.repo
fetched = downstream_repo.remote(git_downstream.remote).fetch()
downstream_repo.git.checkout(fetched)

remote = git_downstream.url if use_url else git_downstream.remote
downstream_exp = git_downstream.tmp_dir.dvc.experiments
with pytest.raises(InvalidArgumentError):
Expand Down Expand Up @@ -200,6 +208,9 @@ def test_pull(tmp_dir, scm, dvc, git_downstream, exp_stage, use_url):
downstream_exp.pull(remote, [str(ref_info1)])
assert git_downstream.tmp_dir.scm.get_ref(str(ref_info1)) == exp1

downstream_exp.pull(remote)
assert git_downstream.tmp_dir.scm.get_ref(str(ref_info3)) == exp3


@pytest.mark.parametrize("all_,rev,result3", [(True, False, True), (False, True, None)])
def test_pull_args(tmp_dir, scm, dvc, git_downstream, exp_stage, all_, rev, result3):
Expand Down
14 changes: 2 additions & 12 deletions tests/unit/command/test_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,12 +294,7 @@ def test_experiments_push(dvc, scm, mocker):
)
cmd = cli_args.func(cli_args)

with pytest.raises(InvalidArgumentError) as exp_info:
cmd.run()
assert (
str(exp_info.value) == "Either provide an `experiment` argument"
", or use the `--rev` or `--all-commits` flag."
)
assert cmd.run() == 0


def test_experiments_pull(dvc, scm, mocker):
Expand Down Expand Up @@ -351,12 +346,7 @@ def test_experiments_pull(dvc, scm, mocker):
)
cmd = cli_args.func(cli_args)

with pytest.raises(InvalidArgumentError) as exp_info:
cmd.run()
assert (
str(exp_info.value) == "Either provide an `experiment` argument"
", or use the `--rev` or `--all-commits` flag."
)
assert cmd.run() == 0


def test_experiments_remove_flag(dvc, scm, mocker, capsys, caplog):
Expand Down