Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/scmrepo/git/backend/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,9 @@ def fetch_refspecs(
returns True the local ref will be overwritten.
Callback will be of the form:
on_diverged(local_refname, remote_sha)

Returns:
Mapping of local_refname to sync status.
"""

@abstractmethod
Expand Down
100 changes: 64 additions & 36 deletions src/scmrepo/git/backend/pygit2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@


if TYPE_CHECKING:
from pygit2 import Signature
from pygit2 import Oid, Signature
from pygit2.remote import Remote # type: ignore
from pygit2.repository import Repository

Expand Down Expand Up @@ -551,7 +551,8 @@ def _merge_remote_branch(
raise SCMError("Unknown merge analysis result")

@contextmanager
def get_remote(self, url: str) -> Generator["Remote", None, None]:
def _get_remote(self, url: str) -> Generator["Remote", None, None]:
"""Return a pygit2.Remote suitable for the specified Git URL or remote name."""
try:
remote = self.repo.remotes[url]
url = remote.url
Expand All @@ -577,57 +578,84 @@ def fetch_refspecs(
progress: Callable[["GitProgressEvent"], None] = None,
**kwargs,
) -> Mapping[str, SyncStatus]:
import fnmatch

from pygit2 import GitError

from .callbacks import RemoteCallbacks

if isinstance(refspecs, str):
refspecs = [refspecs]
refspecs = self._refspecs_list(refspecs, force=force)

with self.get_remote(url) as remote:
fetch_refspecs: List[str] = []
for refspec in refspecs:
if ":" in refspec:
lh, rh = refspec.split(":")
else:
lh = rh = refspec
if not rh.startswith("refs/"):
rh = f"refs/heads/{rh}"
if not lh.startswith("refs/"):
lh = f"refs/heads/{lh}"
rh = rh[len("refs/") :]
refspec = f"+{lh}:refs/remotes/{remote.name}/{rh}"
fetch_refspecs.append(refspec)

logger.debug("fetch_refspecs: %s", fetch_refspecs)
# libgit2 rejects diverged refs but does not have a callback to notify
# when a ref was rejected so we have to determine whether no callback
# means up to date or rejected
def _default_status(
src: str, dst: str, remote_refs: Dict[str, "Oid"]
) -> SyncStatus:
try:
if remote_refs[src] != self.repo.references[dst].target:
return SyncStatus.DIVERGED
except KeyError:
# remote_refs lookup is skipped when force is set, refs cannot
# be diverged on force
pass
return SyncStatus.UP_TO_DATE

with self._get_remote(url) as remote:
with reraise(
GitError,
SCMError(f"Git failed to fetch ref from '{url}'"),
):
with RemoteCallbacks(progress=progress) as cb:
remote_refs: Dict[str, "Oid"] = (
{
head["name"]: head["oid"]
for head in remote.ls_remotes(callbacks=cb)
}
if not force
else {}
)
remote.fetch(
refspecs=fetch_refspecs,
refspecs=refspecs,
callbacks=cb,
message="fetch",
)

result: Dict[str, "SyncStatus"] = {}
for refspec in fetch_refspecs:
_, rh = refspec.split(":")
if not rh.endswith("*"):
refname = rh.split("/", 3)[-1]
refname = f"refs/{refname}"
result[refname] = self._merge_remote_branch(
rh, refname, force, on_diverged
)
continue
rh = rh.rstrip("*").rstrip("/") + "/"
for branch in self.iter_refs(base=rh):
refname = f"refs/{branch[len(rh):]}"
result[refname] = self._merge_remote_branch(
branch, refname, force, on_diverged
)
for refspec in refspecs:
lh, rh = refspec.split(":")
if lh.endswith("*"):
assert rh.endswith("*")
lh_prefix = lh[:-1]
rh_prefix = rh[:-1]
for refname in remote_refs:
if fnmatch.fnmatch(refname, lh):
src = refname
dst = f"{rh_prefix}{refname[len(lh_prefix):]}"
result[dst] = cb.result.get(
src, _default_status(src, dst, remote_refs)
)
else:
result[rh] = cb.result.get(lh, _default_status(lh, rh, remote_refs))

return result

@staticmethod
def _refspecs_list(
refspecs: Union[str, Iterable[str]],
force: bool = False,
) -> List[str]:
if isinstance(refspecs, str):
if force and not refspecs.startswith("+"):
refspecs = f"+{refspecs}"
return [refspecs]
if force:
return [
(refspec if refspec.startswith("+") else f"+{refspec}")
for refspec in refspecs
]
return list(refspecs)

def _stash_iter(self, ref: str):
raise NotImplementedError

Expand Down
11 changes: 10 additions & 1 deletion src/scmrepo/git/backend/pygit2/callbacks.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from contextlib import AbstractContextManager
from types import TracebackType
from typing import TYPE_CHECKING, Callable, Optional, Type, Union
from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, Union

from pygit2 import RemoteCallbacks as _RemoteCallbacks

from scmrepo.git.backend.base import SyncStatus
from scmrepo.git.credentials import Credential, CredentialNotFoundError
from scmrepo.progress import GitProgressReporter

if TYPE_CHECKING:
from pygit2 import Oid
from pygit2.credentials import Keypair, Username, UserPass

from scmrepo.progress import GitProgressEvent
Expand All @@ -27,6 +29,7 @@ def __init__(
self.progress = GitProgressReporter(progress) if progress else None
self._store_credentials: Optional["Credential"] = None
self._tried_credentials = False
self.result: Dict[str, SyncStatus] = {}

def __exit__(
self,
Expand Down Expand Up @@ -66,3 +69,9 @@ def credentials(
def _approve_credentials(self):
if self._store_credentials:
self._store_credentials.approve()

def update_tips(self, refname: str, old: "Oid", new: "Oid"):
if old == new:
self.result[refname] = SyncStatus.UP_TO_DATE
else:
self.result[refname] = SyncStatus.SUCCESS
2 changes: 1 addition & 1 deletion tests/test_pygit2.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def test_pygit_stash_apply_conflicts(
def test_pygit_ssh_error(tmp_dir: TmpDir, scm: Git, url):
backend = Pygit2Backend(tmp_dir)
with pytest.raises(NotImplementedError):
with backend.get_remote(url):
with backend._get_remote(url): # pylint: disable=protected-access
pass


Expand Down