From 2dbf321b73b0f4dd4ed3045deb3853ff8a7d0dae Mon Sep 17 00:00:00 2001 From: David de la Iglesia Castro Date: Fri, 3 Mar 2023 12:57:49 +0100 Subject: [PATCH] pygit2: Raise NotImplementedError on ssh remotes. --- src/scmrepo/git/__init__.py | 14 ++++---------- src/scmrepo/git/backend/pygit2.py | 4 +++- tests/test_pygit2.py | 15 +++++++++++++++ 3 files changed, 22 insertions(+), 11 deletions(-) diff --git a/src/scmrepo/git/__init__.py b/src/scmrepo/git/__init__.py index 1cefdc6c..e86af4d5 100644 --- a/src/scmrepo/git/__init__.py +++ b/src/scmrepo/git/__init__.py @@ -327,18 +327,12 @@ def fetch_refspecs( progress: Optional[Callable[["GitProgressEvent"], None]] = None, **kwargs, ) -> typing.Mapping[str, SyncStatus]: - from urllib.parse import urlparse - from .credentials import get_matching_helper_commands - if "dulwich" in kwargs.get("backends", self.backends.backends): - credentials_helper = any( - get_matching_helper_commands(url, self.dulwich.repo.get_config_stack()) - ) - parsed = urlparse(url) - ssh = parsed.scheme in ("git", "git+ssh", "ssh") or url.startswith("git@") - if credentials_helper or ssh: - kwargs["backends"] = ["dulwich"] + if "dulwich" in kwargs.get("backends", self.backends.backends) and any( + get_matching_helper_commands(url, self.dulwich.repo.get_config_stack()) + ): + kwargs["backends"] = ["dulwich"] return self._fetch_refspecs( url, diff --git a/src/scmrepo/git/backend/pygit2.py b/src/scmrepo/git/backend/pygit2.py index b6b0978c..05b5c0be 100644 --- a/src/scmrepo/git/backend/pygit2.py +++ b/src/scmrepo/git/backend/pygit2.py @@ -16,6 +16,7 @@ Tuple, Union, ) +from urllib.parse import urlparse from funcy import cached_property, reraise from shortuuid import uuid @@ -460,7 +461,8 @@ def get_remote(self, url: str) -> Generator["Remote", None, None]: except KeyError: raise SCMError(f"'{url}' is not a valid Git remote or URL") - if os.name == "nt" and url.startswith("ssh://"): + parsed = urlparse(url) + if parsed.scheme in ("git", "git+ssh", "ssh") or url.startswith("git@"): raise NotImplementedError if os.name == "nt" and url.startswith("file://"): url = url[len("file://") :] diff --git a/tests/test_pygit2.py b/tests/test_pygit2.py index 41f9f793..fa0ca63f 100644 --- a/tests/test_pygit2.py +++ b/tests/test_pygit2.py @@ -1,3 +1,4 @@ +# pylint: disable=unused-argument import pygit2 import pytest from pytest_mock import MockerFixture @@ -58,3 +59,17 @@ def test_pygit_stash_apply_conflicts( strategy=expected_strategy, reinstate_index=False, ) + + +@pytest.mark.parametrize( + "url", + [ + "git@github.com:iterative/scmrepo.git", + "ssh://login@server.com:12345/repository.git", + ], +) +def test_pygit2_ssh_error(tmp_dir: TmpDir, scm: Git, url): + backend = Pygit2Backend(tmp_dir) + with pytest.raises(NotImplementedError): + with backend.get_remote(url): + pass