diff --git a/dvc/repo/experiments/pull.py b/dvc/repo/experiments/pull.py index 488e6a3cae..89ef960361 100644 --- a/dvc/repo/experiments/pull.py +++ b/dvc/repo/experiments/pull.py @@ -1,10 +1,13 @@ import logging -from typing import Iterable, Optional, Set, Union +from typing import Iterable, List, Mapping, Optional, Set, Union + +from funcy import group_by +from scmrepo.git.backend.base import SyncStatus -from dvc.exceptions import DvcException from dvc.repo import locked from dvc.repo.scm_context import scm_context from dvc.scm import TqdmGit, iter_revs +from dvc.ui import ui from .base import ExpRefInfo from .exceptions import UnresolvedExpNamesError @@ -52,40 +55,52 @@ def pull( for _, ref_info_list in ref_info_dict.items(): exp_ref_set.update(ref_info_list) - _pull(repo, git_remote, exp_ref_set, force) + pull_result = _pull(repo, git_remote, exp_ref_set, force) + + if pull_result[SyncStatus.DIVERGED]: + diverged_refs = [ref.name for ref in pull_result[SyncStatus.DIVERGED]] + ui.warn( + f"Local experiment '{diverged_refs}' has diverged from remote " + "experiment with the same name. To override the local experiment " + "re-run with '--force'." + ) + if pull_cache: - _pull_cache(repo, exp_ref_set, **kwargs) - return [ref.name for ref in exp_ref_set] + pull_cache_ref = ( + pull_result[SyncStatus.UP_TO_DATE] + + pull_result[SyncStatus.SUCCESS] + ) + _pull_cache(repo, pull_cache_ref, **kwargs) + + return [ref.name for ref in pull_result[SyncStatus.SUCCESS]] def _pull( repo, git_remote: str, - refs, + refs: Iterable["ExpRefInfo"], force: bool, -): - def on_diverged(refname: str, rev: str) -> bool: - if repo.scm.get_ref(refname) == rev: - return True - exp_name = refname.split("/")[-1] - raise DvcException( - f"Local experiment '{exp_name}' has diverged from remote " - "experiment with the same name. To override the local experiment " - "re-run with '--force'." - ) - +) -> Mapping[SyncStatus, List["ExpRefInfo"]]: refspec_list = [f"{exp_ref}:{exp_ref}" for exp_ref in refs] logger.debug(f"git pull experiment '{git_remote}' -> '{refspec_list}'") with TqdmGit(desc="Fetching git refs") as pbar: - repo.scm.fetch_refspecs( + results: Mapping[str, SyncStatus] = repo.scm.fetch_refspecs( git_remote, refspec_list, force=force, - on_diverged=on_diverged, progress=pbar.update_git, ) + def group_result(refspec): + return results[str(refspec)] + + pull_result: Mapping[SyncStatus, List["ExpRefInfo"]] = group_by( + group_result, refs + ) + + return pull_result + def _pull_cache( repo, diff --git a/dvc/repo/experiments/push.py b/dvc/repo/experiments/push.py index 308822ecbb..7356a27e08 100644 --- a/dvc/repo/experiments/push.py +++ b/dvc/repo/experiments/push.py @@ -1,20 +1,17 @@ import logging -from typing import Iterable, Optional, Set, Union +from typing import Iterable, List, Mapping, Optional, Set, Union + +from funcy import group_by +from scmrepo.git.backend.base import SyncStatus -from dvc.exceptions import DvcException from dvc.repo import locked from dvc.repo.scm_context import scm_context from dvc.scm import TqdmGit, iter_revs +from dvc.ui import ui from .base import ExpRefInfo from .exceptions import UnresolvedExpNamesError -from .utils import ( - exp_commits, - exp_refs, - exp_refs_by_baseline, - push_refspec, - resolve_name, -) +from .utils import exp_commits, exp_refs, exp_refs_by_baseline, resolve_name logger = logging.getLogger(__name__) @@ -60,10 +57,21 @@ def push( for _, ref_info_list in ref_info_dict.items(): exp_ref_set.update(ref_info_list) - _push(repo, git_remote, exp_ref_set, force) + push_result = _push(repo, git_remote, exp_ref_set, force) + if push_result[SyncStatus.DIVERGED]: + diverged_refs = [ref.name for ref in push_result[SyncStatus.DIVERGED]] + ui.warn( + f"Local experiment '{diverged_refs}' has diverged from remote " + "experiment with the same name. To override the remote experiment " + "re-run with '--force'." + ) if push_cache: - _push_cache(repo, exp_ref_set, **kwargs) - return [ref.name for ref in exp_ref_set] + push_cache_ref = ( + push_result[SyncStatus.UP_TO_DATE] + + push_result[SyncStatus.SUCCESS] + ) + _push_cache(repo, push_cache_ref, **kwargs) + return [ref.name for ref in push_result[SyncStatus.SUCCESS]] def _push( @@ -71,30 +79,33 @@ def _push( git_remote: str, refs: Iterable["ExpRefInfo"], force: bool, -): - def on_diverged(refname: str, rev: str) -> bool: - if repo.scm.get_ref(refname) == rev: - return True - exp_name = refname.split("/")[-1] - raise DvcException( - f"Local experiment '{exp_name}' has diverged from remote " - "experiment with the same name. To override the remote experiment " - "re-run with '--force'." - ) +) -> Mapping[SyncStatus, List["ExpRefInfo"]]: + from scmrepo.exceptions import AuthError + + from ...scm import GitAuthError + refspec_list = [f"{exp_ref}:{exp_ref}" for exp_ref in refs] logger.debug(f"git push experiment '{refs}' -> '{git_remote}'") - for exp_ref in refs: - with TqdmGit(desc="Pushing git refs") as pbar: - push_refspec( - repo.scm, + with TqdmGit(desc="Pushing git refs") as pbar: + try: + results: Mapping[str, SyncStatus] = repo.scm.push_refspecs( git_remote, - str(exp_ref), - str(exp_ref), + refspec_list, force=force, - on_diverged=on_diverged, progress=pbar.update_git, ) + except AuthError as exc: + raise GitAuthError(str(exc)) + + def group_result(refspec): + return results[str(refspec)] + + pull_result: Mapping[SyncStatus, List["ExpRefInfo"]] = group_by( + group_result, refs + ) + + return pull_result def _push_cache( diff --git a/dvc/repo/experiments/utils.py b/dvc/repo/experiments/utils.py index f6ba77713f..a61b27b64f 100644 --- a/dvc/repo/experiments/utils.py +++ b/dvc/repo/experiments/utils.py @@ -86,13 +86,37 @@ def push_refspec( **kwargs, ): from scmrepo.exceptions import AuthError - - from ...scm import GitAuthError + from scmrepo.git.backend.base import SyncStatus + + from ...scm import GitAuthError, SCMError + + if not src: + refspecs = [f":{dest}"] + elif src.endswith("/"): + refspecs = [] + dest = dest.rstrip("/") + "/" + for ref in scm.iter_refs(base=src): + refname = ref.split("/")[-1] + refspecs.append(f"{ref}:{dest}{refname}") + else: + if dest.endswith("/"): + refname = src.split("/")[-1] + refspecs = [f"{src}:{dest}/{refname}"] + else: + refspecs = [f"{src}:{dest}"] try: - return scm.push_refspec( - url, src, dest, force=force, on_diverged=on_diverged, **kwargs + results = scm.push_refspecs( + url, refspecs, force=force, on_diverged=on_diverged, **kwargs ) + diverged = [ + ref for ref in results if results[ref] == SyncStatus.DIVERGED + ] + + if diverged: + raise SCMError( + f"local ref '{diverged}' diverged from remote '{url}'" + ) except AuthError as exc: raise GitAuthError(str(exc)) diff --git a/setup.cfg b/setup.cfg index 075be26124..c790267d4a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -72,7 +72,7 @@ install_requires = aiohttp-retry>=2.4.5 diskcache>=5.2.1 jaraco.windows>=5.7.0; python_version < '3.8' and sys_platform == 'win32' - scmrepo==0.0.18 + scmrepo==0.0.19 dvc-render==0.0.5 dvclive>=0.7.3 diff --git a/tests/func/experiments/test_remote.py b/tests/func/experiments/test_remote.py index efcb2eea55..060de1a2f7 100644 --- a/tests/func/experiments/test_remote.py +++ b/tests/func/experiments/test_remote.py @@ -3,7 +3,6 @@ import pytest from funcy import first -from dvc.exceptions import DvcException from dvc.repo.experiments.utils import exp_refs_by_rev @@ -81,8 +80,7 @@ def test_push_diverged(tmp_dir, scm, dvc, git_upstream, exp_stage): git_upstream.tmp_dir.scm.set_ref(str(ref_info), remote_rev) - with pytest.raises(DvcException): - dvc.experiments.push(git_upstream.remote, [ref_info.name]) + assert dvc.experiments.push(git_upstream.remote, [ref_info.name]) == [] assert git_upstream.tmp_dir.scm.get_ref(str(ref_info)) == remote_rev dvc.experiments.push(git_upstream.remote, [ref_info.name], force=True) @@ -254,8 +252,7 @@ def test_pull_diverged(tmp_dir, scm, dvc, git_downstream, exp_stage): git_downstream.tmp_dir.scm.set_ref(str(ref_info), remote_rev) downstream_exp = git_downstream.tmp_dir.dvc.experiments - with pytest.raises(DvcException): - downstream_exp.pull(git_downstream.remote, ref_info.name) + assert downstream_exp.pull(git_downstream.remote, ref_info.name) == [] assert git_downstream.tmp_dir.scm.get_ref(str(ref_info)) == remote_rev downstream_exp.pull(git_downstream.remote, ref_info.name, force=True) diff --git a/tests/unit/repo/experiments/test_utils.py b/tests/unit/repo/experiments/test_utils.py index ae41154918..5ce9f37462 100644 --- a/tests/unit/repo/experiments/test_utils.py +++ b/tests/unit/repo/experiments/test_utils.py @@ -23,7 +23,7 @@ def test_resolve_exp_ref(tmp_dir, scm, git_upstream, name_only, use_url): assert str(result[name]) == ref assert result["notexist"] is None - scm.push_refspec(git_upstream.url, ref, ref) + scm.push_refspecs(git_upstream.url, f"{ref}:{ref}") remote = git_upstream.url if use_url else git_upstream.remote name = "foo" if name_only else ref remote_ref_info = resolve_name(scm, [name], remote)[name]