Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 82 additions & 0 deletions src/scmrepo/git/backend/dulwich/asyncssh_vendor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@
TYPE_CHECKING,
Callable,
Coroutine,
Dict,
Iterator,
List,
Optional,
Sequence,
cast,
)

from asyncssh import SSHClient
Expand All @@ -25,6 +27,7 @@
from asyncssh.connection import SSHClientConnection
from asyncssh.misc import MaybeAwait
from asyncssh.process import SSHClientProcess
from asyncssh.public_key import KeyPairListArg, SSHKey
from asyncssh.stream import SSHReader


Expand Down Expand Up @@ -136,6 +139,85 @@ def _process_public_key_ok_gh(self, _pkttype, _pktid, packet):


class InteractiveSSHClient(SSHClient):
_conn: Optional["SSHClientConnection"] = None
_keys_to_try: Optional[List["FilePath"]] = None
_passphrases: Dict[str, str] = {}

def connection_made(self, conn: "SSHClientConnection") -> None:
self._conn = conn
self._keys_to_try = None

def connection_lost(self, exc: Optional[Exception]) -> None:
self._conn = None

async def public_key_auth_requested( # pylint: disable=invalid-overridden-method
self,
) -> Optional["KeyPairListArg"]:
from asyncssh.public_key import (
KeyImportError,
SSHLocalKeyPair,
read_private_key,
read_public_key,
)

if os.environ.get("GIT_TERMINAL_PROMPT") == "0":
return None

assert self._conn is not None
if self._keys_to_try is None:
self._keys_to_try = []
options = self._conn._options # pylint: disable=protected-access
config = options.config
client_keys = cast(Sequence["FilePath"], config.get("IdentityFile", ()))
for key_to_load in client_keys:
try:
read_private_key(key_to_load, passphrase=options.passphrase)
except KeyImportError as exc:
if str(exc).startswith("Passphrase"):
self._keys_to_try.append(key_to_load)

while self._keys_to_try:
key_to_load = self._keys_to_try.pop()
pubkey_to_load = str(key_to_load) + ".pub"
try:
key = await self._read_private_key_interactive(key_to_load)
except KeyImportError:
continue
try:
pubkey = read_public_key(pubkey_to_load)
except (OSError, KeyImportError):
pubkey = None
return SSHLocalKeyPair(key, pubkey)
return None

async def _read_private_key_interactive(self, path: "FilePath") -> "SSHKey":
from getpass import getpass

from asyncssh.public_key import (
KeyEncryptionError,
KeyImportError,
read_private_key,
)

path = str(path)
passphrase = self._passphrases.get(path)
if passphrase:
return read_private_key(path, passphrase=passphrase)

loop = asyncio.get_running_loop()
for _ in range(3):
passphrase = await loop.run_in_executor(
None, getpass, f"Enter passphrase for key '{path}': "
)
if passphrase:
try:
key = read_private_key(path, passphrase=passphrase)
self._passphrases[path] = passphrase
return key
except (KeyImportError, KeyEncryptionError):
pass
raise KeyImportError("Incorrect passphrase")

def kbdint_auth_requested(self) -> "MaybeAwait[Optional[str]]":
return ""

Expand Down