diff --git a/dvc/repo/experiments/pull.py b/dvc/repo/experiments/pull.py index cd649d8206..d2440bf4a6 100644 --- a/dvc/repo/experiments/pull.py +++ b/dvc/repo/experiments/pull.py @@ -4,7 +4,7 @@ from dvc.repo import locked from dvc.repo.scm_context import scm_context -from .utils import exp_commits, resolve_exp_ref +from .utils import exp_commits, resolve_name logger = logging.getLogger(__name__) @@ -14,7 +14,8 @@ def pull( repo, git_remote, exp_name, *args, force=False, pull_cache=False, **kwargs ): - exp_ref = resolve_exp_ref(repo.scm, exp_name, git_remote) + exp_ref_dict = resolve_name(repo.scm, exp_name, git_remote) + exp_ref = exp_ref_dict[exp_name] if not exp_ref: raise InvalidArgumentError( f"Experiment '{exp_name}' does not exist in '{git_remote}'" diff --git a/dvc/repo/experiments/push.py b/dvc/repo/experiments/push.py index 6c77644440..b0b43e4828 100644 --- a/dvc/repo/experiments/push.py +++ b/dvc/repo/experiments/push.py @@ -4,7 +4,7 @@ from dvc.repo import locked from dvc.repo.scm_context import scm_context -from .utils import exp_commits, push_refspec, resolve_exp_ref +from .utils import exp_commits, push_refspec, resolve_name logger = logging.getLogger(__name__) @@ -20,7 +20,8 @@ def push( push_cache=False, **kwargs, ): - exp_ref = resolve_exp_ref(repo.scm, exp_name) + exp_ref_dict = resolve_name(repo.scm, exp_name) + exp_ref = exp_ref_dict[exp_name] if not exp_ref: raise InvalidArgumentError( f"'{exp_name}' is not a valid experiment name" diff --git a/dvc/repo/experiments/remove.py b/dvc/repo/experiments/remove.py index 08a7396a1a..65c1ac56cf 100644 --- a/dvc/repo/experiments/remove.py +++ b/dvc/repo/experiments/remove.py @@ -6,7 +6,7 @@ from dvc.repo.scm_context import scm_context from dvc.scm import RevError -from .utils import exp_refs, push_refspec, remove_exp_refs, resolve_exp_ref +from .utils import exp_refs, push_refspec, remove_exp_refs, resolve_name logger = logging.getLogger(__name__) @@ -69,9 +69,8 @@ def _remove_commited_exps( ) -> List[str]: remain_list = [] remove_list = [] - for exp_name in exp_names: - ref_info = resolve_exp_ref(repo.scm, exp_name, remote) - + ref_info_dict = resolve_name(repo.scm, exp_names, remote) + for exp_name, ref_info in ref_info_dict.items(): if ref_info: remove_list.append(ref_info) else: diff --git a/dvc/repo/experiments/utils.py b/dvc/repo/experiments/utils.py index 6a87bd8999..c645c32242 100644 --- a/dvc/repo/experiments/utils.py +++ b/dvc/repo/experiments/utils.py @@ -1,4 +1,14 @@ -from typing import Callable, Generator, Iterable, Optional, Set +from collections import defaultdict +from typing import ( + Callable, + Dict, + Generator, + Iterable, + List, + Optional, + Set, + Union, +) from scmrepo.git import Git @@ -13,6 +23,23 @@ ) +class AmbiguousExpRefInfo(InvalidArgumentError): + def __init__( + self, + exp_name: str, + exp_ref_list: Iterable[ExpRefInfo], + ): + msg = [ + ( + f"Ambiguous name '{exp_name}' refers to multiple experiments." + " Use one of the following full refnames instead:" + ), + "", + ] + msg.extend([f"\t{info}" for info in exp_ref_list]) + super().__init__("\n".join(msg)) + + def exp_refs(scm: "Git") -> Generator["ExpRefInfo", None, None]: """Iterate over all experiment refs.""" for ref in scm.iter_refs(base=EXPS_NAMESPACE): @@ -30,15 +57,6 @@ def exp_refs_by_rev( yield ExpRefInfo.from_ref(ref) -def exp_refs_by_name( - scm: "Git", name: str -) -> Generator["ExpRefInfo", None, None]: - """Iterate over all experiment refs matching the specified name.""" - for ref_info in exp_refs(scm): - if ref_info.name == name: - yield ref_info - - def exp_refs_by_baseline( scm: "Git", rev: str ) -> Generator["ExpRefInfo", None, None]: @@ -96,13 +114,17 @@ def remote_exp_refs( yield ExpRefInfo.from_ref(ref) -def remote_exp_refs_by_name( - scm: "Git", url: str, name: str -) -> Generator["ExpRefInfo", None, None]: - """Iterate over all remote experiment refs matching the specified name.""" - for ref_info in remote_exp_refs(scm, url): - if ref_info.name == name: - yield ref_info +def exp_refs_by_names( + scm: "Git", names: Set[str], url: Optional[str] = None +) -> Dict[str, List[ExpRefInfo]]: + """Iterate over all experiment refs matching the specified names.""" + resolve_results = defaultdict(list) + ref_info_gen = remote_exp_refs(scm, url) if url else exp_refs(scm) + for ref_info in ref_info_gen: + if ref_info.name in names: + resolve_results[ref_info.name].append(ref_info) + + return resolve_results def remote_exp_refs_by_baseline( @@ -155,45 +177,39 @@ def fix_exp_head(scm: "Git", ref: Optional[str]) -> Optional[str]: return ref -def resolve_exp_ref( - scm, exp_name: str, git_remote: Optional[str] = None -) -> Optional[ExpRefInfo]: - if exp_name.startswith("refs/"): - return ExpRefInfo.from_ref(exp_name) - - if git_remote: - exp_ref_list = list(remote_exp_refs_by_name(scm, git_remote, exp_name)) - else: - exp_ref_list = list(exp_refs_by_name(scm, exp_name)) - - if not exp_ref_list: - return None - if len(exp_ref_list) > 1: - cur_rev = scm.get_rev() - for info in exp_ref_list: - if info.baseline_sha == cur_rev: - return info - if git_remote: - msg = [ - ( - f"Ambiguous name '{exp_name}' refers to multiple " - "experiments. Use full refname to push one of the " - "following:" - ), - "", - ] +def resolve_name( + scm: "Git", + exp_names: Union[Iterable[str], str], + git_remote: Optional[str] = None, +) -> Dict[str, Optional[ExpRefInfo]]: + """find the ref_info of specified names.""" + if isinstance(exp_names, str): + exp_names = [exp_names] + + result = {} + unresolved = set() + for exp_name in exp_names: + if exp_name.startswith("refs/"): + result[exp_name] = ExpRefInfo.from_ref(exp_name) else: - msg = [ - ( - f"Ambiguous name '{exp_name}' refers to multiple " - f"experiments in '{git_remote}'. Use full refname to pull " - "one of the following:" - ), - "", - ] - msg.extend([f"\t{info}" for info in exp_ref_list]) - raise InvalidArgumentError("\n".join(msg)) - return exp_ref_list[0] + unresolved.add(exp_name) + + unresolved_result = exp_refs_by_names(scm, unresolved, git_remote) + cur_rev = scm.get_rev() + for name in unresolved: + ref_info_list = unresolved_result[name] + if not ref_info_list: + result[name] = None + elif len(ref_info_list) == 1: + result[name] = ref_info_list[0] + else: + for ref_info in ref_info_list: + if ref_info.baseline_sha == cur_rev: + result[name] = ref_info + break + else: + raise AmbiguousExpRefInfo(name, ref_info_list) + return result def check_ref_format(scm: "Git", ref: ExpRefInfo): diff --git a/dvc/scm.py b/dvc/scm.py index c7b6ab551d..a9eee7b17e 100644 --- a/dvc/scm.py +++ b/dvc/scm.py @@ -114,12 +114,17 @@ def resolve_rev(scm: "Git", rev: str) -> str: except InternalRevError as exc: # `scm` will only resolve git branch and tag names, # if rev is not a sha it may be an abbreviated experiment name - if not scm.is_sha(rev) and not rev.startswith("refs/"): - from dvc.repo.experiments.utils import exp_refs_by_name - - ref_infos = list(exp_refs_by_name(scm, rev)) - if len(ref_infos) == 1: - return scm.get_ref(str(ref_infos[0])) - if len(ref_infos) > 1: + if not rev.startswith("refs/"): + from dvc.repo.experiments.utils import ( + AmbiguousExpRefInfo, + resolve_name, + ) + + try: + ref_infos = resolve_name(scm, rev).get(rev) + except AmbiguousExpRefInfo: raise RevError(f"ambiguous Git revision '{rev}'") + if ref_infos: + return scm.get_ref(str(ref_infos)) + raise RevError(str(exc)) diff --git a/tests/unit/repo/experiments/test_utils.py b/tests/unit/repo/experiments/test_utils.py index 0d5cb8ddf0..ae41154918 100644 --- a/tests/unit/repo/experiments/test_utils.py +++ b/tests/unit/repo/experiments/test_utils.py @@ -2,7 +2,7 @@ from dvc.exceptions import InvalidArgumentError from dvc.repo.experiments.base import EXPS_NAMESPACE, ExpRefInfo -from dvc.repo.experiments.utils import check_ref_format, resolve_exp_ref +from dvc.repo.experiments.utils import check_ref_format, resolve_name def commit_exp_ref(tmp_dir, scm, file="foo", contents="foo", name="foo"): @@ -17,13 +17,16 @@ def commit_exp_ref(tmp_dir, scm, file="foo", contents="foo", name="foo"): @pytest.mark.parametrize("name_only", [True, False]) def test_resolve_exp_ref(tmp_dir, scm, git_upstream, name_only, use_url): ref, _ = commit_exp_ref(tmp_dir, scm) - ref_info = resolve_exp_ref(scm, "foo" if name_only else ref) - assert isinstance(ref_info, ExpRefInfo) - assert str(ref_info) == ref + name = "foo" if name_only else ref + result = resolve_name(scm, [name, "notexist"]) + assert isinstance(result[name], ExpRefInfo) + assert str(result[name]) == ref + assert result["notexist"] is None scm.push_refspec(git_upstream.url, ref, ref) remote = git_upstream.url if use_url else git_upstream.remote - remote_ref_info = resolve_exp_ref(scm, "foo" if name_only else ref, remote) + name = "foo" if name_only else ref + remote_ref_info = resolve_name(scm, [name], remote)[name] assert isinstance(remote_ref_info, ExpRefInfo) assert str(remote_ref_info) == ref