From 8913fea109ab085f2526bd2a610fdcc6286183a4 Mon Sep 17 00:00:00 2001 From: karajan1001 Date: Tue, 19 Apr 2022 19:12:20 +0800 Subject: [PATCH] exp pull/push: better result handling in exp sharing(#7448) fix: #7448 Currently we skip experiments that already exist on both sides, but do not exclude them from the result list. And if an experiment has diverged, pulling/pushing it stops the whole operation and raises an exception, without giving any information about the status of the other experiments. 1. Change the exp pull and push to match the new API in `scmrepo`. 2. Handle sync status after `exp pull/push` has finished. 3. Bump scmrepo to 0.0.19 Co-authored-by: --- dvc/repo/experiments/pull.py | 53 ++++++++++------- dvc/repo/experiments/push.py | 69 +++++++++++++---------- dvc/repo/experiments/utils.py | 32 +++++++++-- setup.cfg | 2 +- tests/func/experiments/test_remote.py | 7 +-- tests/unit/repo/experiments/test_utils.py | 2 +- 6 files changed, 106 insertions(+), 59 deletions(-) diff --git a/dvc/repo/experiments/pull.py b/dvc/repo/experiments/pull.py index 488e6a3cae..89ef960361 100644 --- a/dvc/repo/experiments/pull.py +++ b/dvc/repo/experiments/pull.py @@ -1,10 +1,13 @@ import logging -from typing import Iterable, Optional, Set, Union +from typing import Iterable, List, Mapping, Optional, Set, Union + +from funcy import group_by +from scmrepo.git.backend.base import SyncStatus -from dvc.exceptions import DvcException from dvc.repo import locked from dvc.repo.scm_context import scm_context from dvc.scm import TqdmGit, iter_revs +from dvc.ui import ui from .base import ExpRefInfo from .exceptions import UnresolvedExpNamesError @@ -52,40 +55,52 @@ def pull( for _, ref_info_list in ref_info_dict.items(): exp_ref_set.update(ref_info_list) - _pull(repo, git_remote, exp_ref_set, force) + pull_result = _pull(repo, git_remote, exp_ref_set, force) + + if pull_result[SyncStatus.DIVERGED]: + diverged_refs = [ref.name for ref in pull_result[SyncStatus.DIVERGED]] + 
ui.warn( + f"Local experiment '{diverged_refs}' has diverged from remote " + "experiment with the same name. To override the local experiment " + "re-run with '--force'." + ) + if pull_cache: - _pull_cache(repo, exp_ref_set, **kwargs) - return [ref.name for ref in exp_ref_set] + pull_cache_ref = ( + pull_result[SyncStatus.UP_TO_DATE] + + pull_result[SyncStatus.SUCCESS] + ) + _pull_cache(repo, pull_cache_ref, **kwargs) + + return [ref.name for ref in pull_result[SyncStatus.SUCCESS]] def _pull( repo, git_remote: str, - refs, + refs: Iterable["ExpRefInfo"], force: bool, -): - def on_diverged(refname: str, rev: str) -> bool: - if repo.scm.get_ref(refname) == rev: - return True - exp_name = refname.split("/")[-1] - raise DvcException( - f"Local experiment '{exp_name}' has diverged from remote " - "experiment with the same name. To override the local experiment " - "re-run with '--force'." - ) - +) -> Mapping[SyncStatus, List["ExpRefInfo"]]: refspec_list = [f"{exp_ref}:{exp_ref}" for exp_ref in refs] logger.debug(f"git pull experiment '{git_remote}' -> '{refspec_list}'") with TqdmGit(desc="Fetching git refs") as pbar: - repo.scm.fetch_refspecs( + results: Mapping[str, SyncStatus] = repo.scm.fetch_refspecs( git_remote, refspec_list, force=force, - on_diverged=on_diverged, progress=pbar.update_git, ) + def group_result(refspec): + return results[str(refspec)] + + pull_result: Mapping[SyncStatus, List["ExpRefInfo"]] = group_by( + group_result, refs + ) + + return pull_result + def _pull_cache( repo, diff --git a/dvc/repo/experiments/push.py b/dvc/repo/experiments/push.py index 308822ecbb..7356a27e08 100644 --- a/dvc/repo/experiments/push.py +++ b/dvc/repo/experiments/push.py @@ -1,20 +1,17 @@ import logging -from typing import Iterable, Optional, Set, Union +from typing import Iterable, List, Mapping, Optional, Set, Union + +from funcy import group_by +from scmrepo.git.backend.base import SyncStatus -from dvc.exceptions import DvcException from dvc.repo import locked from 
dvc.repo.scm_context import scm_context from dvc.scm import TqdmGit, iter_revs +from dvc.ui import ui from .base import ExpRefInfo from .exceptions import UnresolvedExpNamesError -from .utils import ( - exp_commits, - exp_refs, - exp_refs_by_baseline, - push_refspec, - resolve_name, -) +from .utils import exp_commits, exp_refs, exp_refs_by_baseline, resolve_name logger = logging.getLogger(__name__) @@ -60,10 +57,21 @@ def push( for _, ref_info_list in ref_info_dict.items(): exp_ref_set.update(ref_info_list) - _push(repo, git_remote, exp_ref_set, force) + push_result = _push(repo, git_remote, exp_ref_set, force) + if push_result[SyncStatus.DIVERGED]: + diverged_refs = [ref.name for ref in push_result[SyncStatus.DIVERGED]] + ui.warn( + f"Local experiment '{diverged_refs}' has diverged from remote " + "experiment with the same name. To override the remote experiment " + "re-run with '--force'." + ) if push_cache: - _push_cache(repo, exp_ref_set, **kwargs) - return [ref.name for ref in exp_ref_set] + push_cache_ref = ( + push_result[SyncStatus.UP_TO_DATE] + + push_result[SyncStatus.SUCCESS] + ) + _push_cache(repo, push_cache_ref, **kwargs) + return [ref.name for ref in push_result[SyncStatus.SUCCESS]] def _push( @@ -71,30 +79,33 @@ def _push( git_remote: str, refs: Iterable["ExpRefInfo"], force: bool, -): - def on_diverged(refname: str, rev: str) -> bool: - if repo.scm.get_ref(refname) == rev: - return True - exp_name = refname.split("/")[-1] - raise DvcException( - f"Local experiment '{exp_name}' has diverged from remote " - "experiment with the same name. To override the remote experiment " - "re-run with '--force'." 
- ) +) -> Mapping[SyncStatus, List["ExpRefInfo"]]: + from scmrepo.exceptions import AuthError + + from ...scm import GitAuthError + refspec_list = [f"{exp_ref}:{exp_ref}" for exp_ref in refs] logger.debug(f"git push experiment '{refs}' -> '{git_remote}'") - for exp_ref in refs: - with TqdmGit(desc="Pushing git refs") as pbar: - push_refspec( - repo.scm, + with TqdmGit(desc="Pushing git refs") as pbar: + try: + results: Mapping[str, SyncStatus] = repo.scm.push_refspecs( git_remote, - str(exp_ref), - str(exp_ref), + refspec_list, force=force, - on_diverged=on_diverged, progress=pbar.update_git, ) + except AuthError as exc: + raise GitAuthError(str(exc)) + + def group_result(refspec): + return results[str(refspec)] + + pull_result: Mapping[SyncStatus, List["ExpRefInfo"]] = group_by( + group_result, refs + ) + + return pull_result def _push_cache( diff --git a/dvc/repo/experiments/utils.py b/dvc/repo/experiments/utils.py index f6ba77713f..a61b27b64f 100644 --- a/dvc/repo/experiments/utils.py +++ b/dvc/repo/experiments/utils.py @@ -86,13 +86,37 @@ def push_refspec( **kwargs, ): from scmrepo.exceptions import AuthError - - from ...scm import GitAuthError + from scmrepo.git.backend.base import SyncStatus + + from ...scm import GitAuthError, SCMError + + if not src: + refspecs = [f":{dest}"] + elif src.endswith("/"): + refspecs = [] + dest = dest.rstrip("/") + "/" + for ref in scm.iter_refs(base=src): + refname = ref.split("/")[-1] + refspecs.append(f"{ref}:{dest}{refname}") + else: + if dest.endswith("/"): + refname = src.split("/")[-1] + refspecs = [f"{src}:{dest}/{refname}"] + else: + refspecs = [f"{src}:{dest}"] try: - return scm.push_refspec( - url, src, dest, force=force, on_diverged=on_diverged, **kwargs + results = scm.push_refspecs( + url, refspecs, force=force, on_diverged=on_diverged, **kwargs ) + diverged = [ + ref for ref in results if results[ref] == SyncStatus.DIVERGED + ] + + if diverged: + raise SCMError( + f"local ref '{diverged}' diverged from remote 
'{url}'" + ) except AuthError as exc: raise GitAuthError(str(exc)) diff --git a/setup.cfg b/setup.cfg index 075be26124..c790267d4a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -72,7 +72,7 @@ install_requires = aiohttp-retry>=2.4.5 diskcache>=5.2.1 jaraco.windows>=5.7.0; python_version < '3.8' and sys_platform == 'win32' - scmrepo==0.0.18 + scmrepo==0.0.19 dvc-render==0.0.5 dvclive>=0.7.3 diff --git a/tests/func/experiments/test_remote.py b/tests/func/experiments/test_remote.py index efcb2eea55..060de1a2f7 100644 --- a/tests/func/experiments/test_remote.py +++ b/tests/func/experiments/test_remote.py @@ -3,7 +3,6 @@ import pytest from funcy import first -from dvc.exceptions import DvcException from dvc.repo.experiments.utils import exp_refs_by_rev @@ -81,8 +80,7 @@ def test_push_diverged(tmp_dir, scm, dvc, git_upstream, exp_stage): git_upstream.tmp_dir.scm.set_ref(str(ref_info), remote_rev) - with pytest.raises(DvcException): - dvc.experiments.push(git_upstream.remote, [ref_info.name]) + assert dvc.experiments.push(git_upstream.remote, [ref_info.name]) == [] assert git_upstream.tmp_dir.scm.get_ref(str(ref_info)) == remote_rev dvc.experiments.push(git_upstream.remote, [ref_info.name], force=True) @@ -254,8 +252,7 @@ def test_pull_diverged(tmp_dir, scm, dvc, git_downstream, exp_stage): git_downstream.tmp_dir.scm.set_ref(str(ref_info), remote_rev) downstream_exp = git_downstream.tmp_dir.dvc.experiments - with pytest.raises(DvcException): - downstream_exp.pull(git_downstream.remote, ref_info.name) + assert downstream_exp.pull(git_downstream.remote, ref_info.name) == [] assert git_downstream.tmp_dir.scm.get_ref(str(ref_info)) == remote_rev downstream_exp.pull(git_downstream.remote, ref_info.name, force=True) diff --git a/tests/unit/repo/experiments/test_utils.py b/tests/unit/repo/experiments/test_utils.py index ae41154918..5ce9f37462 100644 --- a/tests/unit/repo/experiments/test_utils.py +++ b/tests/unit/repo/experiments/test_utils.py @@ -23,7 +23,7 @@ def 
test_resolve_exp_ref(tmp_dir, scm, git_upstream, name_only, use_url): assert str(result[name]) == ref assert result["notexist"] is None - scm.push_refspec(git_upstream.url, ref, ref) + scm.push_refspecs(git_upstream.url, f"{ref}:{ref}") remote = git_upstream.url if use_url else git_upstream.remote name = "foo" if name_only else ref remote_ref_info = resolve_name(scm, [name], remote)[name]