From 4980cba073d61d2e515d5eb736b46471b6c2c584 Mon Sep 17 00:00:00 2001 From: Peter Rowlands Date: Thu, 20 Jul 2023 17:23:39 +0900 Subject: [PATCH] dulwich: support interactive SSH key passphrase prompt in asyncssh vendor --- .../git/backend/dulwich/asyncssh_vendor.py | 82 +++++++++++++++++++ 1 file changed, 82 insertions(+) diff --git a/src/scmrepo/git/backend/dulwich/asyncssh_vendor.py b/src/scmrepo/git/backend/dulwich/asyncssh_vendor.py index 68af1531..d532cab0 100644 --- a/src/scmrepo/git/backend/dulwich/asyncssh_vendor.py +++ b/src/scmrepo/git/backend/dulwich/asyncssh_vendor.py @@ -5,10 +5,12 @@ TYPE_CHECKING, Callable, Coroutine, + Dict, Iterator, List, Optional, Sequence, + cast, ) from asyncssh import SSHClient @@ -25,6 +27,7 @@ from asyncssh.connection import SSHClientConnection from asyncssh.misc import MaybeAwait from asyncssh.process import SSHClientProcess + from asyncssh.public_key import KeyPairListArg, SSHKey from asyncssh.stream import SSHReader @@ -136,6 +139,85 @@ def _process_public_key_ok_gh(self, _pkttype, _pktid, packet): class InteractiveSSHClient(SSHClient): + _conn: Optional["SSHClientConnection"] = None + _keys_to_try: Optional[List["FilePath"]] = None + _passphrases: Dict[str, str] = {} + + def connection_made(self, conn: "SSHClientConnection") -> None: + self._conn = conn + self._keys_to_try = None + + def connection_lost(self, exc: Optional[Exception]) -> None: + self._conn = None + + async def public_key_auth_requested( # pylint: disable=invalid-overridden-method + self, + ) -> Optional["KeyPairListArg"]: + from asyncssh.public_key import ( + KeyImportError, + SSHLocalKeyPair, + read_private_key, + read_public_key, + ) + + if os.environ.get("GIT_TERMINAL_PROMPT") == "0": + return None + + assert self._conn is not None + if self._keys_to_try is None: + self._keys_to_try = [] + options = self._conn._options # pylint: disable=protected-access + config = options.config + client_keys = cast(Sequence["FilePath"], config.get("IdentityFile", ())) + for key_to_load in client_keys: + try: + read_private_key(key_to_load, passphrase=options.passphrase) + except KeyImportError as exc: + if str(exc).startswith("Passphrase"): + self._keys_to_try.append(key_to_load) + + while self._keys_to_try: + key_to_load = self._keys_to_try.pop() + pubkey_to_load = str(key_to_load) + ".pub" + try: + key = await self._read_private_key_interactive(key_to_load) + except KeyImportError: + continue + try: + pubkey = read_public_key(pubkey_to_load) + except (OSError, KeyImportError): + pubkey = None + return SSHLocalKeyPair(key, pubkey) + return None + + async def _read_private_key_interactive(self, path: "FilePath") -> "SSHKey": + from getpass import getpass + + from asyncssh.public_key import ( + KeyEncryptionError, + KeyImportError, + read_private_key, + ) + + path = str(path) + passphrase = self._passphrases.get(path) + if passphrase: + return read_private_key(path, passphrase=passphrase) + + loop = asyncio.get_running_loop() + for _ in range(3): + passphrase = await loop.run_in_executor( + None, getpass, f"Enter passphrase for key '{path}': " + ) + if passphrase: + try: + key = read_private_key(path, passphrase=passphrase) + self._passphrases[path] = passphrase + return key + except (KeyImportError, KeyEncryptionError): + pass + raise KeyImportError("Incorrect passphrase") + def kbdint_auth_requested(self) -> "MaybeAwait[Optional[str]]": return ""