Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ module = [
"asyncssh.*",
"pygit2.*",
"pytest_docker.plugin",
"urllib3.*",
]
ignore_missing_imports = true

Expand Down
39 changes: 35 additions & 4 deletions src/scmrepo/git/backend/dulwich/asyncssh_vendor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""asyncssh SSH vendor for Dulwich."""
import asyncio
import os
from typing import (
TYPE_CHECKING,
Callable,
Expand All @@ -10,6 +11,7 @@
Sequence,
)

from asyncssh import SSHClient
from dulwich.client import SSHVendor

from scmrepo.asyn import BaseAsyncObject, sync_wrapper
Expand All @@ -18,8 +20,10 @@
if TYPE_CHECKING:
from pathlib import Path

from asyncssh.auth import KbdIntPrompts, KbdIntResponse
from asyncssh.config import ConfigPaths, FilePath
from asyncssh.connection import SSHClientConnection
from asyncssh.misc import MaybeAwait
from asyncssh.process import SSHClientProcess
from asyncssh.stream import SSHReader

Expand Down Expand Up @@ -131,6 +135,36 @@ def _process_public_key_ok_gh(self, _pkttype, _pktid, packet):
return True


class InteractiveSSHClient(SSHClient):
def kbdint_auth_requested(self) -> "MaybeAwait[Optional[str]]":
return ""

async def kbdint_challenge_received( # pylint: disable=invalid-overridden-method
self,
name: str,
instructions: str,
lang: str,
prompts: "KbdIntPrompts",
) -> Optional["KbdIntResponse"]:
from getpass import getpass

if os.environ.get("GIT_TERMINAL_PROMPT") == "0":
return None

def _getpass(prompt: str) -> str:
return getpass(prompt=prompt).rstrip()

if instructions:
print(instructions)
loop = asyncio.get_running_loop()
return [
await loop.run_in_executor(
None, _getpass, f"({name}) {prompt}" if name else prompt
)
for prompt, _ in prompts
]


class AsyncSSHVendor(BaseAsyncObject, SSHVendor):
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
Expand Down Expand Up @@ -176,6 +210,7 @@ async def _run_command(
ignore_encrypted=not key_filename,
known_hosts=None,
encoding=None,
client_factory=InteractiveSSHClient,
)
proc = await conn.create_process(command, encoding=None)
except asyncssh.misc.PermissionDenied as exc:
Expand All @@ -185,10 +220,6 @@ async def _run_command(
run_command = sync_wrapper(_run_command)


# class ValidatedSSHClientConfig(SSHClientConfig):
# pass


def get_unsupported_opts(config_paths: "ConfigPaths") -> Iterator[str]:
from pathlib import Path, PurePath

Expand Down
52 changes: 48 additions & 4 deletions src/scmrepo/git/backend/dulwich/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Optional
import os
from typing import Any, Dict, Optional

from dulwich.client import Urllib3HttpGitClient
from dulwich.client import HTTPUnauthorized, Urllib3HttpGitClient

from scmrepo.git.credentials import Credential, CredentialNotFoundError

Expand Down Expand Up @@ -37,8 +38,51 @@ def __init__(
self.pool_manager.headers.update(basic_auth)
self._store_credentials = creds

def _http_request(self, *args, **kwargs):
result = super()._http_request(*args, **kwargs)
def _http_request(
self,
url: str,
headers: Optional[Dict[str, str]] = None,
data: Any = None,
):
try:
result = super()._http_request(url, headers=headers, data=data)
except HTTPUnauthorized:
auth_header = self._get_auth()
if not auth_header:
raise
if headers:
headers.update(auth_header)
else:
headers = auth_header
result = super()._http_request(url, headers=headers, data=data)
if self._store_credentials is not None:
self._store_credentials.approve()
return result

def _get_auth(self) -> Dict[str, str]:
from getpass import getpass

from urllib3.util import make_headers

try:
creds = Credential(username=self._username, url=self._base_url).fill()
self._store_credentials = creds
return make_headers(basic_auth=f"{creds.username}:{creds.password}")
except CredentialNotFoundError:
pass

if os.environ.get("GIT_TERMINAL_PROMPT") == "0":
return {}

try:
if self._username:
username = self._username
else:
username = input(f"Username for '{self._base_url}': ")
if self._password:
password = self._password
else:
password = getpass(f"Password for '{self._base_url}': ")
return make_headers(basic_auth=f"{username}:{password}")
except KeyboardInterrupt:
return {}
44 changes: 37 additions & 7 deletions src/scmrepo/git/backend/pygit2/callbacks.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import os
from contextlib import AbstractContextManager
from typing import TYPE_CHECKING, Callable, Optional, Union
from types import TracebackType
from typing import TYPE_CHECKING, Callable, Optional, Type, Union

from pygit2 import RemoteCallbacks as _RemoteCallbacks

Expand All @@ -20,14 +22,21 @@ def __init__(
self,
*args,
progress: Optional[Callable[["GitProgressEvent"], None]] = None,
**kwargs
**kwargs,
):
super().__init__(*args, **kwargs)
self.progress = GitProgressReporter(progress) if progress else None
self._store_credentials: Optional["Credential"] = None
self._tried_credentials = False

def __exit__(self, *args, **kwargs):
self._approve_credentials()
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
):
if exc_type is None:
self._approve_credentials()

def sideband_progress(self, string: str):
if self.progress is not None:
Expand All @@ -36,16 +45,37 @@ def sideband_progress(self, string: str):
def credentials(
self, url: str, username_from_url: Optional[str], allowed_types: int
) -> "_Pygit2Credential":
from pygit2 import Passthrough
from getpass import getpass

from pygit2 import GitError, Passthrough
from pygit2.credentials import GIT_CREDENTIAL_USERPASS_PLAINTEXT, UserPass

if self._tried_credentials:
raise GitError(f"authentication failed for '{url}'")
self._tried_credentials = True

if allowed_types & GIT_CREDENTIAL_USERPASS_PLAINTEXT:
try:
creds = Credential(username=username_from_url, url=url).fill()
self._store_credentials = creds
if self._store_credentials:
creds = self._store_credentials
else:
Credential(username=username_from_url, url=url).fill()
self._store_credentials = creds
return UserPass(creds.username, creds.password)
except CredentialNotFoundError:
pass

if os.environ.get("GIT_TERMINAL_PROMPT") != "0":
try:
if username_from_url:
username = username_from_url
else:
username = input(f"Username for '{url}': ")
password = getpass(f"Password for '{url}': ")
if username and password:
return UserPass(username, password)
except KeyboardInterrupt:
pass
raise Passthrough

def _approve_credentials(self):
Expand Down