Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 34 additions & 19 deletions dvc/repo/experiments/pull.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import logging
from typing import Iterable, Optional, Set, Union
from typing import Iterable, List, Mapping, Optional, Set, Union

from funcy import group_by
from scmrepo.git.backend.base import SyncStatus

from dvc.exceptions import DvcException
from dvc.repo import locked
from dvc.repo.scm_context import scm_context
from dvc.scm import TqdmGit, iter_revs
from dvc.ui import ui

from .base import ExpRefInfo
from .exceptions import UnresolvedExpNamesError
Expand Down Expand Up @@ -52,40 +55,52 @@ def pull(
for _, ref_info_list in ref_info_dict.items():
exp_ref_set.update(ref_info_list)

_pull(repo, git_remote, exp_ref_set, force)
pull_result = _pull(repo, git_remote, exp_ref_set, force)

if pull_result[SyncStatus.DIVERGED]:
diverged_refs = [ref.name for ref in pull_result[SyncStatus.DIVERGED]]
ui.warn(
f"Local experiment '{diverged_refs}' has diverged from remote "
"experiment with the same name. To override the local experiment "
"re-run with '--force'."
)

if pull_cache:
_pull_cache(repo, exp_ref_set, **kwargs)
return [ref.name for ref in exp_ref_set]
pull_cache_ref = (
pull_result[SyncStatus.UP_TO_DATE]
+ pull_result[SyncStatus.SUCCESS]
)
_pull_cache(repo, pull_cache_ref, **kwargs)

return [ref.name for ref in pull_result[SyncStatus.SUCCESS]]


def _pull(
repo,
git_remote: str,
refs,
refs: Iterable["ExpRefInfo"],
force: bool,
):
def on_diverged(refname: str, rev: str) -> bool:
if repo.scm.get_ref(refname) == rev:
return True
exp_name = refname.split("/")[-1]
raise DvcException(
f"Local experiment '{exp_name}' has diverged from remote "
"experiment with the same name. To override the local experiment "
"re-run with '--force'."
)

) -> Mapping[SyncStatus, List["ExpRefInfo"]]:
refspec_list = [f"{exp_ref}:{exp_ref}" for exp_ref in refs]
logger.debug(f"git pull experiment '{git_remote}' -> '{refspec_list}'")

with TqdmGit(desc="Fetching git refs") as pbar:
repo.scm.fetch_refspecs(
results: Mapping[str, SyncStatus] = repo.scm.fetch_refspecs(
git_remote,
refspec_list,
force=force,
on_diverged=on_diverged,
progress=pbar.update_git,
)

def group_result(refspec):
return results[str(refspec)]

pull_result: Mapping[SyncStatus, List["ExpRefInfo"]] = group_by(
group_result, refs
)

return pull_result


def _pull_cache(
repo,
Expand Down
69 changes: 40 additions & 29 deletions dvc/repo/experiments/push.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,17 @@
import logging
from typing import Iterable, Optional, Set, Union
from typing import Iterable, List, Mapping, Optional, Set, Union

from funcy import group_by
from scmrepo.git.backend.base import SyncStatus

from dvc.exceptions import DvcException
from dvc.repo import locked
from dvc.repo.scm_context import scm_context
from dvc.scm import TqdmGit, iter_revs
from dvc.ui import ui

from .base import ExpRefInfo
from .exceptions import UnresolvedExpNamesError
from .utils import (
exp_commits,
exp_refs,
exp_refs_by_baseline,
push_refspec,
resolve_name,
)
from .utils import exp_commits, exp_refs, exp_refs_by_baseline, resolve_name

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -60,41 +57,55 @@ def push(
for _, ref_info_list in ref_info_dict.items():
exp_ref_set.update(ref_info_list)

_push(repo, git_remote, exp_ref_set, force)
push_result = _push(repo, git_remote, exp_ref_set, force)
if push_result[SyncStatus.DIVERGED]:
diverged_refs = [ref.name for ref in push_result[SyncStatus.DIVERGED]]
ui.warn(
f"Local experiment '{diverged_refs}' has diverged from remote "
"experiment with the same name. To override the remote experiment "
"re-run with '--force'."
)
if push_cache:
_push_cache(repo, exp_ref_set, **kwargs)
return [ref.name for ref in exp_ref_set]
push_cache_ref = (
push_result[SyncStatus.UP_TO_DATE]
+ push_result[SyncStatus.SUCCESS]
)
_push_cache(repo, push_cache_ref, **kwargs)
return [ref.name for ref in push_result[SyncStatus.SUCCESS]]


def _push(
repo,
git_remote: str,
refs: Iterable["ExpRefInfo"],
force: bool,
):
def on_diverged(refname: str, rev: str) -> bool:
if repo.scm.get_ref(refname) == rev:
return True
exp_name = refname.split("/")[-1]
raise DvcException(
f"Local experiment '{exp_name}' has diverged from remote "
"experiment with the same name. To override the remote experiment "
"re-run with '--force'."
)
) -> Mapping[SyncStatus, List["ExpRefInfo"]]:
from scmrepo.exceptions import AuthError

from ...scm import GitAuthError

refspec_list = [f"{exp_ref}:{exp_ref}" for exp_ref in refs]
logger.debug(f"git push experiment '{refs}' -> '{git_remote}'")

for exp_ref in refs:
with TqdmGit(desc="Pushing git refs") as pbar:
push_refspec(
repo.scm,
with TqdmGit(desc="Pushing git refs") as pbar:
try:
results: Mapping[str, SyncStatus] = repo.scm.push_refspecs(
git_remote,
str(exp_ref),
str(exp_ref),
refspec_list,
force=force,
on_diverged=on_diverged,
progress=pbar.update_git,
)
except AuthError as exc:
raise GitAuthError(str(exc))

def group_result(refspec):
return results[str(refspec)]

pull_result: Mapping[SyncStatus, List["ExpRefInfo"]] = group_by(
group_result, refs
)

return pull_result


def _push_cache(
Expand Down
32 changes: 28 additions & 4 deletions dvc/repo/experiments/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,13 +86,37 @@ def push_refspec(
**kwargs,
):
from scmrepo.exceptions import AuthError

from ...scm import GitAuthError
from scmrepo.git.backend.base import SyncStatus

from ...scm import GitAuthError, SCMError

if not src:
refspecs = [f":{dest}"]
elif src.endswith("/"):
refspecs = []
dest = dest.rstrip("/") + "/"
for ref in scm.iter_refs(base=src):
refname = ref.split("/")[-1]
refspecs.append(f"{ref}:{dest}{refname}")
else:
if dest.endswith("/"):
refname = src.split("/")[-1]
refspecs = [f"{src}:{dest}/{refname}"]
else:
refspecs = [f"{src}:{dest}"]

try:
return scm.push_refspec(
url, src, dest, force=force, on_diverged=on_diverged, **kwargs
results = scm.push_refspecs(
url, refspecs, force=force, on_diverged=on_diverged, **kwargs
)
diverged = [
ref for ref in results if results[ref] == SyncStatus.DIVERGED
]

if diverged:
raise SCMError(
f"local ref '{diverged}' diverged from remote '{url}'"
)
except AuthError as exc:
raise GitAuthError(str(exc))

Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ install_requires =
aiohttp-retry>=2.4.5
diskcache>=5.2.1
jaraco.windows>=5.7.0; python_version < '3.8' and sys_platform == 'win32'
scmrepo==0.0.18
scmrepo==0.0.19
dvc-render==0.0.5
dvclive>=0.7.3

Expand Down
7 changes: 2 additions & 5 deletions tests/func/experiments/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import pytest
from funcy import first

from dvc.exceptions import DvcException
from dvc.repo.experiments.utils import exp_refs_by_rev


Expand Down Expand Up @@ -81,8 +80,7 @@ def test_push_diverged(tmp_dir, scm, dvc, git_upstream, exp_stage):

git_upstream.tmp_dir.scm.set_ref(str(ref_info), remote_rev)

with pytest.raises(DvcException):
dvc.experiments.push(git_upstream.remote, [ref_info.name])
assert dvc.experiments.push(git_upstream.remote, [ref_info.name]) == []
assert git_upstream.tmp_dir.scm.get_ref(str(ref_info)) == remote_rev

dvc.experiments.push(git_upstream.remote, [ref_info.name], force=True)
Expand Down Expand Up @@ -254,8 +252,7 @@ def test_pull_diverged(tmp_dir, scm, dvc, git_downstream, exp_stage):
git_downstream.tmp_dir.scm.set_ref(str(ref_info), remote_rev)

downstream_exp = git_downstream.tmp_dir.dvc.experiments
with pytest.raises(DvcException):
downstream_exp.pull(git_downstream.remote, ref_info.name)
assert downstream_exp.pull(git_downstream.remote, ref_info.name) == []
assert git_downstream.tmp_dir.scm.get_ref(str(ref_info)) == remote_rev

downstream_exp.pull(git_downstream.remote, ref_info.name, force=True)
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/repo/experiments/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def test_resolve_exp_ref(tmp_dir, scm, git_upstream, name_only, use_url):
assert str(result[name]) == ref
assert result["notexist"] is None

scm.push_refspec(git_upstream.url, ref, ref)
scm.push_refspecs(git_upstream.url, f"{ref}:{ref}")
remote = git_upstream.url if use_url else git_upstream.remote
name = "foo" if name_only else ref
remote_ref_info = resolve_name(scm, [name], remote)[name]
Expand Down