Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions dvc/repo/experiments/pull.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from dvc.repo import locked
from dvc.repo.scm_context import scm_context

from .utils import exp_commits, resolve_exp_ref
from .utils import exp_commits, resolve_name

logger = logging.getLogger(__name__)

Expand All @@ -14,7 +14,8 @@
def pull(
repo, git_remote, exp_name, *args, force=False, pull_cache=False, **kwargs
):
exp_ref = resolve_exp_ref(repo.scm, exp_name, git_remote)
exp_ref_dict = resolve_name(repo.scm, exp_name, git_remote)
exp_ref = exp_ref_dict[exp_name]
if not exp_ref:
raise InvalidArgumentError(
f"Experiment '{exp_name}' does not exist in '{git_remote}'"
Expand Down
5 changes: 3 additions & 2 deletions dvc/repo/experiments/push.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from dvc.repo import locked
from dvc.repo.scm_context import scm_context

from .utils import exp_commits, push_refspec, resolve_exp_ref
from .utils import exp_commits, push_refspec, resolve_name

logger = logging.getLogger(__name__)

Expand All @@ -20,7 +20,8 @@ def push(
push_cache=False,
**kwargs,
):
exp_ref = resolve_exp_ref(repo.scm, exp_name)
exp_ref_dict = resolve_name(repo.scm, exp_name)
exp_ref = exp_ref_dict[exp_name]
if not exp_ref:
raise InvalidArgumentError(
f"'{exp_name}' is not a valid experiment name"
Expand Down
7 changes: 3 additions & 4 deletions dvc/repo/experiments/remove.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from dvc.repo.scm_context import scm_context
from dvc.scm import RevError

from .utils import exp_refs, push_refspec, remove_exp_refs, resolve_exp_ref
from .utils import exp_refs, push_refspec, remove_exp_refs, resolve_name

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -69,9 +69,8 @@ def _remove_commited_exps(
) -> List[str]:
remain_list = []
remove_list = []
for exp_name in exp_names:
ref_info = resolve_exp_ref(repo.scm, exp_name, remote)

ref_info_dict = resolve_name(repo.scm, exp_names, remote)
for exp_name, ref_info in ref_info_dict.items():
if ref_info:
remove_list.append(ref_info)
else:
Expand Down
126 changes: 71 additions & 55 deletions dvc/repo/experiments/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,14 @@
from typing import Callable, Generator, Iterable, Optional, Set
from collections import defaultdict
from typing import (
Callable,
Dict,
Generator,
Iterable,
List,
Optional,
Set,
Union,
)

from scmrepo.git import Git

Expand All @@ -13,6 +23,23 @@
)


class AmbiguousExpRefInfo(InvalidArgumentError):
    """Error for a short experiment name that matches several refs.

    The message lists every candidate ref, one per line and tab-indented,
    so the caller can retry with a full refname.
    """

    def __init__(
        self,
        exp_name: str,
        exp_ref_list: Iterable[ExpRefInfo],
    ):
        header = (
            f"Ambiguous name '{exp_name}' refers to multiple experiments."
            " Use one of the following full refnames instead:"
        )
        # A blank line separates the header from the candidate list.
        candidates = (f"\t{info}" for info in exp_ref_list)
        super().__init__("\n".join([header, "", *candidates]))


def exp_refs(scm: "Git") -> Generator["ExpRefInfo", None, None]:
"""Iterate over all experiment refs."""
for ref in scm.iter_refs(base=EXPS_NAMESPACE):
Expand All @@ -30,15 +57,6 @@ def exp_refs_by_rev(
yield ExpRefInfo.from_ref(ref)


def exp_refs_by_name(
    scm: "Git", name: str
) -> Generator["ExpRefInfo", None, None]:
    """Yield each local experiment ref whose name equals ``name``."""
    yield from (info for info in exp_refs(scm) if info.name == name)


def exp_refs_by_baseline(
scm: "Git", rev: str
) -> Generator["ExpRefInfo", None, None]:
Expand Down Expand Up @@ -96,13 +114,17 @@ def remote_exp_refs(
yield ExpRefInfo.from_ref(ref)


def remote_exp_refs_by_name(
    scm: "Git", url: str, name: str
) -> Generator["ExpRefInfo", None, None]:
    """Yield each experiment ref on remote ``url`` whose name equals ``name``."""
    yield from (
        info for info in remote_exp_refs(scm, url) if info.name == name
    )
def exp_refs_by_names(
    scm: "Git", names: Set[str], url: Optional[str] = None
) -> Dict[str, List[ExpRefInfo]]:
    """Group experiment refs by name, keeping only the requested names.

    Refs are read from the remote ``url`` when one is given, otherwise
    from the local repository. The returned mapping is a ``defaultdict``,
    so looking up a name that matched no ref produces an empty list.
    """
    if url:
        candidates = remote_exp_refs(scm, url)
    else:
        candidates = exp_refs(scm)

    matches = defaultdict(list)
    for info in candidates:
        ref_name = info.name
        if ref_name in names:
            matches[ref_name].append(info)

    return matches


def remote_exp_refs_by_baseline(
Expand Down Expand Up @@ -155,45 +177,39 @@ def fix_exp_head(scm: "Git", ref: Optional[str]) -> Optional[str]:
return ref


def resolve_exp_ref(
scm, exp_name: str, git_remote: Optional[str] = None
) -> Optional[ExpRefInfo]:
if exp_name.startswith("refs/"):
return ExpRefInfo.from_ref(exp_name)

if git_remote:
exp_ref_list = list(remote_exp_refs_by_name(scm, git_remote, exp_name))
else:
exp_ref_list = list(exp_refs_by_name(scm, exp_name))

if not exp_ref_list:
return None
if len(exp_ref_list) > 1:
cur_rev = scm.get_rev()
for info in exp_ref_list:
if info.baseline_sha == cur_rev:
return info
if git_remote:
msg = [
(
f"Ambiguous name '{exp_name}' refers to multiple "
"experiments. Use full refname to push one of the "
"following:"
),
"",
]
def resolve_name(
scm: "Git",
exp_names: Union[Iterable[str], str],
git_remote: Optional[str] = None,
) -> Dict[str, Optional[ExpRefInfo]]:
"""find the ref_info of specified names."""
if isinstance(exp_names, str):
exp_names = [exp_names]

result = {}
unresolved = set()
for exp_name in exp_names:
if exp_name.startswith("refs/"):
result[exp_name] = ExpRefInfo.from_ref(exp_name)
else:
msg = [
(
f"Ambiguous name '{exp_name}' refers to multiple "
f"experiments in '{git_remote}'. Use full refname to pull "
"one of the following:"
),
"",
]
msg.extend([f"\t{info}" for info in exp_ref_list])
raise InvalidArgumentError("\n".join(msg))
return exp_ref_list[0]
unresolved.add(exp_name)

unresolved_result = exp_refs_by_names(scm, unresolved, git_remote)
cur_rev = scm.get_rev()
for name in unresolved:
ref_info_list = unresolved_result[name]
if not ref_info_list:
result[name] = None
elif len(ref_info_list) == 1:
result[name] = ref_info_list[0]
else:
for ref_info in ref_info_list:
if ref_info.baseline_sha == cur_rev:
result[name] = ref_info
break
else:
raise AmbiguousExpRefInfo(name, ref_info_list)
return result


def check_ref_format(scm: "Git", ref: ExpRefInfo):
Expand Down
19 changes: 12 additions & 7 deletions dvc/scm.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,17 @@ def resolve_rev(scm: "Git", rev: str) -> str:
except InternalRevError as exc:
# `scm` will only resolve git branch and tag names,
# if rev is not a sha it may be an abbreviated experiment name
if not scm.is_sha(rev) and not rev.startswith("refs/"):
from dvc.repo.experiments.utils import exp_refs_by_name

ref_infos = list(exp_refs_by_name(scm, rev))
if len(ref_infos) == 1:
return scm.get_ref(str(ref_infos[0]))
if len(ref_infos) > 1:
if not rev.startswith("refs/"):
from dvc.repo.experiments.utils import (
AmbiguousExpRefInfo,
resolve_name,
)

try:
ref_infos = resolve_name(scm, rev).get(rev)
except AmbiguousExpRefInfo:
raise RevError(f"ambiguous Git revision '{rev}'")
if ref_infos:
return scm.get_ref(str(ref_infos))

raise RevError(str(exc))
13 changes: 8 additions & 5 deletions tests/unit/repo/experiments/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from dvc.exceptions import InvalidArgumentError
from dvc.repo.experiments.base import EXPS_NAMESPACE, ExpRefInfo
from dvc.repo.experiments.utils import check_ref_format, resolve_exp_ref
from dvc.repo.experiments.utils import check_ref_format, resolve_name


def commit_exp_ref(tmp_dir, scm, file="foo", contents="foo", name="foo"):
Expand All @@ -17,13 +17,16 @@ def commit_exp_ref(tmp_dir, scm, file="foo", contents="foo", name="foo"):
@pytest.mark.parametrize("name_only", [True, False])
def test_resolve_exp_ref(tmp_dir, scm, git_upstream, name_only, use_url):
ref, _ = commit_exp_ref(tmp_dir, scm)
ref_info = resolve_exp_ref(scm, "foo" if name_only else ref)
assert isinstance(ref_info, ExpRefInfo)
assert str(ref_info) == ref
name = "foo" if name_only else ref
result = resolve_name(scm, [name, "notexist"])
assert isinstance(result[name], ExpRefInfo)
assert str(result[name]) == ref
assert result["notexist"] is None

scm.push_refspec(git_upstream.url, ref, ref)
remote = git_upstream.url if use_url else git_upstream.remote
remote_ref_info = resolve_exp_ref(scm, "foo" if name_only else ref, remote)
name = "foo" if name_only else ref
remote_ref_info = resolve_name(scm, [name], remote)[name]
assert isinstance(remote_ref_info, ExpRefInfo)
assert str(remote_ref_info) == ref

Expand Down