diff --git a/scmrepo/git/__init__.py b/scmrepo/git/__init__.py index 205cce3e..5747a70c 100644 --- a/scmrepo/git/__init__.py +++ b/scmrepo/git/__init__.py @@ -296,6 +296,7 @@ def add_commit( add = partialmethod(_backend_func, "add") commit = partialmethod(_backend_func, "commit") checkout = partialmethod(_backend_func, "checkout") + fetch = partialmethod(_backend_func, "fetch") pull = partialmethod(_backend_func, "pull") push = partialmethod(_backend_func, "push") branch = partialmethod(_backend_func, "branch") diff --git a/scmrepo/git/backend/base.py b/scmrepo/git/backend/base.py index 54c25a21..d5e23acc 100644 --- a/scmrepo/git/backend/base.py +++ b/scmrepo/git/backend/base.py @@ -82,6 +82,15 @@ def checkout( ): pass + @abstractmethod + def fetch( + self, + remote: Optional[str] = None, + force: bool = False, + unshallow: bool = False, + ): + pass + @abstractmethod def pull(self, **kwargs): pass diff --git a/scmrepo/git/backend/dulwich/__init__.py b/scmrepo/git/backend/dulwich/__init__.py index 0a01c498..07cb3efe 100644 --- a/scmrepo/git/backend/dulwich/__init__.py +++ b/scmrepo/git/backend/dulwich/__init__.py @@ -27,6 +27,8 @@ from ..base import BaseGitBackend if TYPE_CHECKING: + from dulwich.repo import Repo + from scmrepo.progress import GitProgressEvent from ...objects import GitCommit @@ -144,8 +146,9 @@ def close(self): def root_dir(self) -> str: return self.repo.path - @staticmethod + @classmethod def clone( + cls, url: str, to_path: str, shallow_branch: Optional[str] = None, @@ -176,12 +179,30 @@ def clone( depth = 0 else: depth = 1 - clone_from(depth=depth, branch=os.fsencode(shallow_branch)) + repo = clone_from( + depth=depth, branch=os.fsencode(shallow_branch) + ) else: - clone_from() + repo = clone_from() + cls._set_default_tracking_branch(repo) except Exception as exc: raise CloneError(url, to_path) from exc + @staticmethod + def _set_default_tracking_branch(repo: "Repo"): + from dulwich.refs import LOCAL_BRANCH_PREFIX, parse_symref_value + + try: + ref = parse_symref_value(repo.refs.read_ref(b"HEAD")) + except ValueError: + return + if ref.startswith(LOCAL_BRANCH_PREFIX): + branch = ref[len(LOCAL_BRANCH_PREFIX) :] + config = repo.get_config() + section = ("branch", os.fsencode(branch)) + config.set(section, b"remote", b"origin") + config.set(section, b"merge", ref) + @staticmethod def init(path: str, bare: bool = False) -> None: from dulwich.porcelain import init @@ -270,6 +291,23 @@ def checkout( ): raise NotImplementedError + def fetch( + self, + remote: Optional[str] = None, + force: bool = False, + unshallow: bool = False, + ): + from dulwich.porcelain import fetch + from dulwich.protocol import DEPTH_INFINITE + + remote_b = os.fsencode(remote) if remote else b"origin" + fetch( + self.repo, + remote_location=remote_b, + force=force, + depth=DEPTH_INFINITE if unshallow else None, + ) + def pull(self, **kwargs): raise NotImplementedError diff --git a/scmrepo/git/backend/gitpython.py b/scmrepo/git/backend/gitpython.py index 49151ddd..0a994134 100644 --- a/scmrepo/git/backend/gitpython.py +++ b/scmrepo/git/backend/gitpython.py @@ -246,6 +246,24 @@ def checkout( else: self.repo.git.checkout(branch, force=force, **kwargs) + def fetch( + self, + remote: Optional[str] = None, + force: bool = False, + unshallow: bool = False, + ): + if not remote: + remote = "origin" + kwargs = {} + if force: + kwargs["force"] = True + if unshallow: + kwargs["unshallow"] = True + infos = self.repo.remote(name=remote).fetch(**kwargs) + for info in infos: + if info.flags & info.ERROR: + raise SCMError(f"fetch failed: {info.note}") + def pull(self, **kwargs): infos = self.repo.remote().pull(**kwargs) for info in infos: diff --git a/scmrepo/git/backend/pygit2.py b/scmrepo/git/backend/pygit2.py index 83fa0081..063d32ea 100644 --- a/scmrepo/git/backend/pygit2.py +++ b/scmrepo/git/backend/pygit2.py @@ -215,6 +215,14 @@ def checkout( else: self.repo.set_head(commit.id) + def fetch( + self, + remote: Optional[str] = None, + force: bool = False, + unshallow: bool = False, + ): + raise NotImplementedError + def pull(self, **kwargs): raise NotImplementedError @@ -243,7 +251,14 @@ def is_dirty(self, untracked_files: bool = False) -> bool: raise NotImplementedError def active_branch(self) -> str: - raise NotImplementedError + if self.repo.head_is_detached: + raise SCMError("No active branch (detached HEAD)") + if self.repo.head_is_unborn: + # if HEAD points to a nonexistent branch we still return the + # branch name (without "refs/heads/" prefix) to match gitpython's + # behavior + return self.repo.references["HEAD"].target[11:] + return self.repo.head.shorthand def list_branches(self) -> Iterable[str]: raise NotImplementedError @@ -588,18 +603,31 @@ def merge( msg: Optional[str] = None, squash: bool = False, ) -> Optional[str]: - from pygit2 import GIT_RESET_MIXED, GitError + from pygit2 import ( + GIT_MERGE_ANALYSIS_FASTFORWARD, + GIT_MERGE_ANALYSIS_NONE, + GIT_MERGE_ANALYSIS_UNBORN, + GIT_MERGE_ANALYSIS_UP_TO_DATE, + GIT_MERGE_PREFERENCE_FASTFORWARD_ONLY, + GIT_MERGE_PREFERENCE_NO_FASTFORWARD, + GitError, + ) if commit and squash: raise SCMError("Cannot merge with 'squash' and 'commit'") - if commit and not msg: - raise SCMError("Merge commit message is required") - with self.release_odb_handles(): + self.repo.index.read(False) + obj, _ref = self.repo.resolve_refish(rev) + analysis, ff_pref = self.repo.merge_analysis(obj.id) + + if analysis == GIT_MERGE_ANALYSIS_NONE: + raise SCMError(f"'{rev}' cannot be merged into HEAD") + if analysis & GIT_MERGE_ANALYSIS_UP_TO_DATE: + return None + try: - self.repo.index.read(False) - self.repo.merge(rev) + self.repo.merge(obj.id) self.repo.index.write() except GitError as exc: raise SCMError("Merge failed") from exc @@ -607,18 +635,54 @@ def merge( if self.repo.index.conflicts: raise MergeConflictError("Merge contained conflicts") - if commit: - user = self.default_signature - tree = self.repo.index.write_tree() - merge_commit = self.repo.create_commit( - "HEAD", user, user, msg, tree, [self.repo.head.target, rev] - ) - return str(merge_commit) - if squash: - self.repo.reset(self.repo.head.target, GIT_RESET_MIXED) + try: + if not ( + squash or ff_pref & GIT_MERGE_PREFERENCE_NO_FASTFORWARD + ): + if analysis & GIT_MERGE_ANALYSIS_FASTFORWARD: + return self._merge_ff(rev, obj) + + if analysis & GIT_MERGE_ANALYSIS_UNBORN: + self.repo.set_head(obj.id) + return str(obj.id) + + if ff_pref & GIT_MERGE_PREFERENCE_FASTFORWARD_ONLY: + raise SCMError("Cannot fast-forward HEAD to '{rev}'") + + if commit: + if not msg: + raise SCMError("Merge commit message is required") + user = self.default_signature + tree = self.repo.index.write_tree() + merge_commit = self.repo.create_commit( + "HEAD", + user, + user, + msg, + tree, + [self.repo.head.target, obj.id], + ) + return str(merge_commit) + + # --squash merge: + # HEAD is not moved and merge changes stay in index + return None + finally: self.repo.state_cleanup() self.repo.index.write() - return None + + def _merge_ff(self, rev: str, obj) -> str: + if self.repo.head_is_detached: + self.repo.set_head(obj.id) + else: + branch = self.get_ref("HEAD", follow=False) + assert branch + self.set_ref( + branch, + str(obj.id), + message=f"merge {rev}: Fast-forward", + ) + return str(obj.id) def validate_git_remote(self, url: str, **kwargs): raise NotImplementedError diff --git a/tests/test_git.py b/tests/test_git.py index adfc3de7..3356ad42 100644 --- a/tests/test_git.py +++ b/tests/test_git.py @@ -897,3 +897,21 @@ def test_clone( target = Git(str(target_dir)) assert target.get_rev() == rev assert (target_dir / "foo").read_text() == "foo" + + +@pytest.mark.skip_git_backend("pygit2") +def test_fetch( + tmp_dir: TmpDir, scm: Git, git: Git, tmp_dir_factory: TempDirFactory +): + tmp_dir.gen("foo", "foo") + scm.add_commit("foo", message="init") + + target_dir = tmp_dir_factory.mktemp("git-clone") + git.clone(str(tmp_dir), (target_dir)) + target = Git(str(target_dir)) + + scm.add_commit("bar", message="update") + rev = scm.get_rev() + + target.fetch() + assert target.get_ref("refs/remotes/origin/master") == rev