Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion scmrepo/git/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,7 @@ def add_commit(
iter_refs = partialmethod(_backend_func, "iter_refs")
iter_remote_refs = partialmethod(_backend_func, "iter_remote_refs")
get_refs_containing = partialmethod(_backend_func, "get_refs_containing")
push_refspec = partialmethod(_backend_func, "push_refspec")
push_refspecs = partialmethod(_backend_func, "push_refspecs")
fetch_refspecs = partialmethod(_backend_func, "fetch_refspecs")
_stash_iter = partialmethod(_backend_func, "_stash_iter")
_stash_push = partialmethod(_backend_func, "_stash_push")
Expand Down
27 changes: 15 additions & 12 deletions scmrepo/git/backend/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
from abc import ABC, abstractmethod
from enum import Enum
from typing import (
TYPE_CHECKING,
Callable,
Expand All @@ -25,6 +26,12 @@ def __init__(self, func):
super().__init__(f"No valid Git backend for '{func}'")


class SyncStatus(Enum):
SUCCESS = 0
UP_TO_DATE = 1
DIVERGED = 2


class BaseGitBackend(ABC):
"""Base Git backend class."""

Expand Down Expand Up @@ -206,25 +213,21 @@ def get_refs_containing(self, rev: str, pattern: Optional[str] = None):
"""Iterate over all git refs containing the specified revision."""

@abstractmethod
def push_refspec(
def push_refspecs(
self,
url: str,
src: Optional[str],
dest: str,
refspecs: Union[str, Iterable[str]],
force: bool = False,
on_diverged: Optional[Callable[[str, str], bool]] = None,
progress: Callable[["GitProgressEvent"], None] = None,
**kwargs,
):
) -> Mapping[str, SyncStatus]:
"""Push refspec to a remote Git repo.

Args:
url: Git remote name or absolute Git URL.
src: Local refspec. If src ends with "/" it will be treated as a
prefix, and all refs inside src will be pushed using dest
as destination refspec prefix. If src is None, dest will be
deleted from the remote.
dest: Remote refspec.
refspecs: Iterable containing refspecs to push.
Note that this will not match subkeys.
force: If True, remote refs will be overwritten.
on_diverged: Callback function which will be called if local ref
and remote have diverged and force is False. If the callback
Expand All @@ -237,12 +240,12 @@ def push_refspec(
def fetch_refspecs(
self,
url: str,
refspecs: Iterable[str],
force: Optional[bool] = False,
refspecs: Union[str, Iterable[str]],
force: bool = False,
on_diverged: Optional[Callable[[str, str], bool]] = None,
progress: Callable[["GitProgressEvent"], None] = None,
**kwargs,
):
) -> Mapping[str, SyncStatus]:
"""Fetch refspecs from a remote Git repo.

Args:
Expand Down
131 changes: 76 additions & 55 deletions scmrepo/git/backend/dulwich/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from scmrepo.utils import relpath

from ...objects import GitObject
from ..base import BaseGitBackend
from ..base import BaseGitBackend, SyncStatus

if TYPE_CHECKING:
from dulwich.repo import Repo
Expand Down Expand Up @@ -488,26 +488,24 @@ def iter_remote_refs(self, url: str, base: Optional[str] = None, **kwargs):
def get_refs_containing(self, rev: str, pattern: Optional[str] = None):
raise NotImplementedError

def push_refspec(
def push_refspecs(
self,
url: str,
src: Optional[str],
dest: str,
refspecs: Union[str, Iterable[str]],
force: bool = False,
on_diverged: Optional[Callable[[str, str], bool]] = None,
progress: Callable[["GitProgressEvent"], None] = None,
**kwargs,
):
) -> Mapping[str, SyncStatus]:
from dulwich.client import HTTPUnauthorized, get_transport_and_path
from dulwich.errors import NotGitRepository, SendPackError
from dulwich.objectspec import parse_reftuples
from dulwich.porcelain import (
DivergedBranches,
check_diverged,
get_remote_repo,
)

dest_refs, values = self._push_dest_refs(src, dest)

try:
_remote, location = get_remote_repo(self.repo, url)
client, path = get_transport_and_path(location, **kwargs)
Expand All @@ -516,26 +514,45 @@ def push_refspec(
f"'{url}' is not a valid Git remote or URL"
) from exc

change_result = {}
selected_refs = []

def update_refs(refs):
from dulwich.objects import ZERO_SHA

selected_refs.extend(
parse_reftuples(self.repo.refs, refs, refspecs, force=force)
)
new_refs = {}
for ref, value in zip(dest_refs, values):
if ref in refs and value != ZERO_SHA:
local_sha = self.repo.refs[ref]
remote_sha = refs[ref]
for (lh, rh, _) in selected_refs:
refname = os.fsdecode(rh)
if rh in refs and lh is not None:
if refs[rh] == self.repo.refs[lh]:
change_result[refname] = SyncStatus.UP_TO_DATE
continue
try:
check_diverged(self.repo, remote_sha, local_sha)
check_diverged(self.repo, refs[rh], self.repo.refs[lh])
except DivergedBranches:
if not force:
overwrite = False
if on_diverged:
overwrite = on_diverged(
os.fsdecode(ref), os.fsdecode(remote_sha)
overwrite = (
on_diverged(
os.fsdecode(lh), os.fsdecode(refs[rh])
)
if on_diverged
else False
)
if not overwrite:
change_result[refname] = SyncStatus.DIVERGED
continue
new_refs[ref] = value

if lh is None:
value = ZERO_SHA
else:
value = self.repo.refs[lh]

new_refs[rh] = value
change_result[refname] = SyncStatus.SUCCESS

return new_refs

try:
Expand All @@ -548,38 +565,23 @@ def update_refs(refs):
),
)
except (NotGitRepository, SendPackError) as exc:
raise SCMError("Git failed to push '{src}' to '{url}'") from exc
src = [lh for (lh, _, _) in selected_refs]
raise SCMError(f"Git failed to push '{src}' to '{url}'") from exc
except HTTPUnauthorized:
raise AuthError(url)

def _push_dest_refs(
self, src: Optional[str], dest: str
) -> Tuple[Iterable[bytes], Iterable[bytes]]:
from dulwich.objects import ZERO_SHA

if src is not None and src.endswith("/"):
src_b = os.fsencode(src)
keys = self.repo.refs.subkeys(src_b)
values = [self.repo.refs[b"".join([src_b, key])] for key in keys]
dest_refs = [b"".join([os.fsencode(dest), key]) for key in keys]
else:
if src is None:
values = [ZERO_SHA]
else:
values = [self.repo.refs[os.fsencode(src)]]
dest_refs = [os.fsencode(dest)]
return dest_refs, values
return change_result

def fetch_refspecs(
self,
url: str,
refspecs: Iterable[str],
refspecs: Union[str, Iterable[str]],
force: Optional[bool] = False,
on_diverged: Optional[Callable[[str, str], bool]] = None,
progress: Callable[["GitProgressEvent"], None] = None,
**kwargs,
):
) -> Mapping[str, SyncStatus]:
from dulwich.client import get_transport_and_path
from dulwich.errors import NotGitRepository
from dulwich.objectspec import parse_reftuples
from dulwich.porcelain import (
DivergedBranches,
Expand All @@ -594,7 +596,7 @@ def determine_wants(remote_refs):
parse_reftuples(
remote_refs,
self.repo.refs,
[os.fsencode(refspec) for refspec in refspecs],
refspecs,
force=force,
)
)
Expand All @@ -612,28 +614,47 @@ def determine_wants(remote_refs):
f"'{url}' is not a valid Git remote or URL"
) from exc

fetch_result = client.fetch(
path,
self.repo,
progress=DulwichProgressReporter(progress) if progress else None,
determine_wants=determine_wants,
)
try:
fetch_result = client.fetch(
path,
self.repo,
progress=DulwichProgressReporter(progress)
if progress
else None,
determine_wants=determine_wants,
)
except NotGitRepository as exc:
raise SCMError(f"Git failed to fetch ref from '{url}'") from exc

result = {}

for (lh, rh, _) in fetch_refs:
try:
if rh in self.repo.refs:
refname = os.fsdecode(rh)
if rh in self.repo.refs:
if self.repo.refs[rh] == fetch_result.refs[lh]:
result[refname] = SyncStatus.UP_TO_DATE
continue
try:
check_diverged(
self.repo, self.repo.refs[rh], fetch_result.refs[lh]
)
except DivergedBranches:
if not force:
overwrite = False
if on_diverged:
overwrite = on_diverged(
os.fsdecode(rh), os.fsdecode(fetch_result.refs[lh])
except DivergedBranches:
if not force:
overwrite = (
on_diverged(
os.fsdecode(rh),
os.fsdecode(fetch_result.refs[lh]),
)
if on_diverged
else False
)
if not overwrite:
continue
if not overwrite:
result[refname] = SyncStatus.DIVERGED
continue

self.repo.refs[rh] = fetch_result.refs[lh]
result[refname] = SyncStatus.SUCCESS
return result

def _stash_iter(self, ref: str):
stash = self._get_stash(ref)
Expand Down
15 changes: 7 additions & 8 deletions scmrepo/git/backend/gitpython.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from scmrepo.utils import relpath

from ..objects import GitCommit, GitObject
from .base import BaseGitBackend
from .base import BaseGitBackend, SyncStatus

if TYPE_CHECKING:
from scmrepo.progress import GitProgressEvent
Expand Down Expand Up @@ -474,27 +474,26 @@ def get_refs_containing(self, rev: str, pattern: Optional[str] = None):
except GitCommandError:
pass

def push_refspec(
def push_refspecs(
self,
url: str,
src: Optional[str],
dest: str,
refspecs: Union[str, Iterable[str]],
force: bool = False,
on_diverged: Optional[Callable[[str, str], bool]] = None,
progress: Callable[["GitProgressEvent"], None] = None,
**kwargs,
):
) -> Mapping[str, SyncStatus]:
raise NotImplementedError

def fetch_refspecs(
self,
url: str,
refspecs: Iterable[str],
force: Optional[bool] = False,
refspecs: Union[str, Iterable[str]],
force: bool = False,
on_diverged: Optional[Callable[[str, str], bool]] = None,
progress: Callable[["GitProgressEvent"], None] = None,
**kwargs,
):
) -> Mapping[str, SyncStatus]:
raise NotImplementedError

def _stash_iter(self, ref: str):
Expand Down
15 changes: 7 additions & 8 deletions scmrepo/git/backend/pygit2.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from scmrepo.utils import relpath

from ..objects import GitCommit, GitObject
from .base import BaseGitBackend
from .base import BaseGitBackend, SyncStatus

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -414,27 +414,26 @@ def _contains(repo, ref, search_commit):
) and _contains(self.repo, ref, search_commit):
yield ref

def push_refspec(
def push_refspecs(
self,
url: str,
src: Optional[str],
dest: str,
refspecs: Union[str, Iterable[str]],
force: bool = False,
on_diverged: Optional[Callable[[str, str], bool]] = None,
progress: Callable[["GitProgressEvent"], None] = None,
**kwargs,
):
) -> Mapping[str, SyncStatus]:
raise NotImplementedError

def fetch_refspecs(
self,
url: str,
refspecs: Iterable[str],
force: Optional[bool] = False,
refspecs: Union[str, Iterable[str]],
force: bool = False,
on_diverged: Optional[Callable[[str, str], bool]] = None,
progress: Callable[["GitProgressEvent"], None] = None,
**kwargs,
):
) -> Mapping[str, SyncStatus]:
raise NotImplementedError

def _stash_iter(self, ref: str):
Expand Down
Loading