Skip to content

Commit

Permalink
Fix error: identify file and passphrase was not copied to new SSHAuth…
Browse files Browse the repository at this point in the history
… instance

* changed type for identify file:
  internally always cast to Collection[str] to reduce memory usage
  (paramiko uses as Iterable[str])
* fixed pylint warnings
  • Loading branch information
penguinolog committed Aug 24, 2020
1 parent 3c3a63f commit 5bdee86
Show file tree
Hide file tree
Showing 10 changed files with 52 additions and 41 deletions.
10 changes: 5 additions & 5 deletions doc/source/SSHClient.rst
Expand Up @@ -10,7 +10,7 @@ API: SSHClient and SSHAuth.
SSHClient helper.

.. py:method:: __init__(host, port=22, username=None, password=None, *, private_keys=None, auth=None, verbose=True, ssh_config=None, ssh_auth_map=None, sock=None, keepalive=1)
.. py:method:: __init__(host, port=22, username=None, password=None, *, auth=None, verbose=True, ssh_config=None, ssh_auth_map=None, sock=None, keepalive=1)
:param host: remote hostname
:type host: ``str``
Expand Down Expand Up @@ -526,7 +526,7 @@ API: SSHClient and SSHAuth.
:param keys: Alternate connection keys
:type keys: ``Optional[Sequence[paramiko.RSAKey]]``
:param key_filename: filename(s) for additional key files
:type key_filename: ``Union[List[str], str, None]``
:type key_filename: ``Union[Iterable[str], str, None]``
:param passphrase: passphrase for keys. Need, if differs from password
:type passphrase: ``Optional[str]``

Expand All @@ -544,7 +544,7 @@ API: SSHClient and SSHAuth.

.. py:attribute:: key_filename
``Union[List[str], str, None]``
``Collection[str]``
Key filename(s).

.. versionadded:: 1.0.0
Expand Down Expand Up @@ -670,7 +670,7 @@ API: SSHClient and SSHAuth.
:param user: remote user
:type user: ``Optional[str]``
:param identityfile: connection ssh keys file names
:type identityfile: ``Optional[List[str]]``
:type identityfile: ``Optional[Collection[str]]``
:param proxycommand: proxy command for ssh connection
:type proxycommand: ``Optional[str]``
:param proxyjump: proxy host name
Expand Down Expand Up @@ -721,7 +721,7 @@ API: SSHClient and SSHAuth.

.. py:attribute:: identityfile
``Optional[List[str]]``
``Collection[str]``
Connection ssh keys file names.

.. py:attribute:: proxycommand
Expand Down
20 changes: 11 additions & 9 deletions exec_helpers/_ssh_helpers.py
Expand Up @@ -13,7 +13,7 @@
# noinspection PyPackageRequirements
import logwrap

SSHConfigDictLikeT = typing.Dict[str, typing.Union[str, int, bool, typing.List[str]]]
SSHConfigDictLikeT = typing.Dict[str, typing.Union[str, int, bool, typing.Collection[str]]]
SSHConfigsDictT = typing.Dict[str, SSHConfigDictLikeT]


Expand Down Expand Up @@ -62,7 +62,7 @@ def __init__(
hostname: str,
port: "typing.Optional[typing.Union[str, int]]" = None,
user: "typing.Optional[str]" = None,
identityfile: "typing.Optional[typing.List[str]]" = None,
identityfile: "typing.Optional[typing.Collection[str]]" = None,
proxycommand: "typing.Optional[str]" = None,
proxyjump: "typing.Optional[str]" = None,
*,
Expand All @@ -78,7 +78,7 @@ def __init__(
:param user: remote user
:type user: typing.Optional[str]
:param identityfile: connection ssh keys file names
:type identityfile: typing.Optional[typing.List[str]]
:type identityfile: typing.Optional[typing.Collection[str]]
:param proxycommand: proxy command for ssh connection
:type proxycommand: typing.Optional[str]
:type proxyjump: typing.Optional[str]
Expand All @@ -97,7 +97,7 @@ def __init__(
raise ValueError(f"port {self.__port} if not in range [1, 65535], which is incorrect.")

self.__user: "typing.Optional[str]" = user
self.__identityfile: "typing.Optional[typing.List[str]]" = identityfile
self.__identityfile: "typing.Optional[typing.Collection[str]]" = identityfile

if proxycommand and proxyjump:
raise ValueError(
Expand Down Expand Up @@ -292,7 +292,7 @@ def __eq__(
)
)
if isinstance(other, dict):
return self.as_dict == other
return self == self.from_ssh_config(other)
return NotImplemented

@property
Expand Down Expand Up @@ -323,15 +323,17 @@ def user(self) -> "typing.Optional[str]":
return self.__user

@property
def identityfile(self) -> "typing.Optional[typing.List[str]]":
def identityfile(self) -> "typing.Collection[str]":
"""Connection ssh keys file names.
:return: list of ssh private keys names
:rtype: typing.Optional[typing.List[str]]
:rtype: typing.Collection[str]
"""
if self.__identityfile is None:
return None
return self.__identityfile.copy()
return ()
if isinstance(self.__identityfile, str):
return (self.__identityfile,)
return tuple(self.__identityfile)

@property
def proxycommand(self) -> "typing.Optional[str]":
Expand Down
6 changes: 3 additions & 3 deletions exec_helpers/async_api/subprocess.py
Expand Up @@ -93,7 +93,7 @@ def stdout(self) -> "typing.Optional[typing.AsyncIterable[bytes]]": # type: ign
:return: STDOUT interface
:rtype: typing.Optional[typing.AsyncIterable[bytes]]
"""
return super(SubprocessExecuteAsyncResult, self).stdout # type: ignore
return super().stdout # type: ignore


class Subprocess(api.ExecHelper):
Expand Down Expand Up @@ -193,12 +193,12 @@ async def poll_stderr() -> None:
exit_code: int = await asyncio.wait_for(async_result.interface.wait(), timeout=timeout)
result.exit_code = exit_code
return result
except asyncio.TimeoutError:
except asyncio.TimeoutError as exc:
# kill -9 for all subprocesses
_subprocess_helpers.kill_proc_tree(async_result.interface.pid)
exit_signal: "typing.Optional[int]" = await asyncio.wait_for(async_result.interface.wait(), timeout=0.001)
if exit_signal is None:
raise exceptions.ExecHelperNoKillError(result=result, timeout=timeout) # type: ignore
raise exceptions.ExecHelperNoKillError(result=result, timeout=timeout) from exc # type: ignore
result.exit_code = exit_signal
finally:
stdout_task.cancel()
Expand Down
2 changes: 1 addition & 1 deletion exec_helpers/exceptions.py
Expand Up @@ -247,7 +247,7 @@ def __init__(
f"Got:\n"
f"\t{errors_str}"
)
super(ParallelCallProcessError, self).__init__(message)
super().__init__(message)
self.cmd: str = command
self.errors: "typing.Dict[typing.Tuple[str, int], exec_result.ExecResult]" = errors
self.results: "typing.Dict[typing.Tuple[str, int], exec_result.ExecResult]" = results
Expand Down
26 changes: 18 additions & 8 deletions exec_helpers/ssh_auth.py
Expand Up @@ -44,7 +44,7 @@ def __init__(
password: "typing.Optional[str]" = None,
key: "typing.Optional[paramiko.PKey]" = None,
keys: "typing.Optional[typing.Sequence[paramiko.PKey]]" = None,
key_filename: "typing.Union[typing.List[str], str, None]" = None,
key_filename: "typing.Union[typing.Iterable[str], str, None]" = None,
passphrase: "typing.Optional[str]" = None,
) -> None:
"""SSH credentials object.
Expand All @@ -62,7 +62,7 @@ def __init__(
:param keys: Alternate connection keys
:type keys: typing.Optional[typing.Sequence[paramiko.PKey]]]
:param key_filename: filename(s) for additional key files
:type key_filename: typing.Union[typing.List[str], str, None]
:type key_filename: typing.Union[typing.Iterable[str], str, None]
:param passphrase: passphrase for keys. Need, if differs from password
:type passphrase: "typing.Optional[str]"
Expand All @@ -87,10 +87,12 @@ def __init__(

self.__key_index: int = 0

if key_filename is None or isinstance(key_filename, list):
self.__key_filename: "typing.Optional[typing.List[str]]" = key_filename
if key_filename is None:
self.__key_filename: "typing.Collection[str]" = ()
elif isinstance(key_filename, str):
self.__key_filename = (key_filename,)
else:
self.__key_filename = [key_filename]
self.__key_filename = tuple(key_filename)
self.__passphrase: "typing.Optional[str]" = passphrase

@property
Expand Down Expand Up @@ -125,13 +127,14 @@ def public_key(self) -> "typing.Optional[str]":
return self.__get_public_key(self.__keys[self.__key_index])

@property
def key_filename(self) -> "typing.Optional[typing.List[str]]":
def key_filename(self) -> "typing.Collection[str]":
"""Key filename(s).
:return: copy of used key filename (original should not be changed via mutability).
.. versionadded:: 1.0.0
.. versionchanged:: 7.0.5 changed type relying on paramiko sources
"""
return copy.deepcopy(self.__key_filename)
return self.__key_filename

def enter_password(self, tgt: typing.BinaryIO) -> None:
"""Enter password to STDIN.
Expand Down Expand Up @@ -255,6 +258,8 @@ def __deepcopy__(self, memo: typing.Any) -> "SSHAuth":
password=self.__password,
key=self.__keys[self.__key_index],
keys=copy.deepcopy(self.__keys),
key_filename=copy.deepcopy(self.key_filename),
passphrase=self.__passphrase,
)

def __copy__(self) -> "SSHAuth":
Expand All @@ -265,7 +270,12 @@ def __copy__(self) -> "SSHAuth":
"""
# noinspection PyTypeChecker
return self.__class__(
username=self.username, password=self.__password, key=self.__keys[self.__key_index], keys=self.__keys
username=self.username,
password=self.__password,
key=self.__keys[self.__key_index],
keys=self.__keys,
key_filename=self.key_filename,
passphrase=self.__passphrase,
)

def __repr__(self) -> str:
Expand Down
8 changes: 3 additions & 5 deletions exec_helpers/subprocess.py
Expand Up @@ -119,9 +119,7 @@ class Subprocess(api.ExecHelper):
def __init__(self, log_mask_re: LogMaskReT = None) -> None:
"""Subprocess helper with timeouts and lock-free FIFO."""
mod_name = "exec_helpers" if self.__module__.startswith("exec_helpers") else self.__module__
super(Subprocess, self).__init__(
logger=logging.getLogger(f"{mod_name}.{self.__class__.__name__}"), log_mask_re=log_mask_re
)
super().__init__(logger=logging.getLogger(f"{mod_name}.{self.__class__.__name__}"), log_mask_re=log_mask_re)

def __enter__(self) -> "Subprocess": # pylint: disable=useless-super-delegation
"""Get context manager.
Expand Down Expand Up @@ -201,12 +199,12 @@ def close_streams() -> None:
concurrent.futures.wait([stdout_future, stderr_future], timeout=0.1) # Minimal timeout to complete polling
result.exit_code = exit_code
return result
except subprocess.TimeoutExpired:
except subprocess.TimeoutExpired as exc:
# kill -9 for all subprocesses
_subprocess_helpers.kill_proc_tree(async_result.interface.pid)
exit_signal: "typing.Optional[int]" = async_result.interface.poll()
if exit_signal is None:
raise exceptions.ExecHelperNoKillError(result=result, timeout=timeout) # type: ignore
raise exceptions.ExecHelperNoKillError(result=result, timeout=timeout) from exc # type: ignore
result.exit_code = exit_signal
finally:
stdout_future.cancel()
Expand Down
8 changes: 4 additions & 4 deletions test/test_ssh_client_execute_through_host.py
Expand Up @@ -138,15 +138,15 @@ def test_01_execute_through_host_no_creds(
]
connect.assert_has_calls(
[
mock.call(hostname=host, password=password, pkey=None, port=port, username=username, key_filename=None),
mock.call(hostname=host, password=password, pkey=None, port=port, username=username, key_filename=()),
mock.call(
hostname=target,
port=port,
username=username,
password=password,
pkey=None,
sock=ssh_intermediate_channel(),
key_filename=None,
key_filename=(),
),
]
)
Expand All @@ -165,15 +165,15 @@ def test_02_execute_through_host_with_creds(
]
connect.assert_has_calls(
[
mock.call(hostname=host, password=password, pkey=None, port=port, username=username, key_filename=None),
mock.call(hostname=host, password=password, pkey=None, port=port, username=username, key_filename=()),
mock.call(
hostname=target,
port=port,
username=username_2,
password=password_2,
pkey=None,
sock=ssh_intermediate_channel(),
key_filename=None,
key_filename=(),
),
]
)
Expand Down
3 changes: 2 additions & 1 deletion test/test_ssh_client_init_basic.py
Expand Up @@ -107,7 +107,7 @@ def test_init_base(paramiko_ssh_client, auto_add_policy, run_parameters, ssh_aut
if auth is None:
expected_calls = [
_ssh.set_missing_host_key_policy("AutoAddPolicy"),
_ssh.connect(hostname=host, password=password, pkey=None, port=port, username=username, key_filename=None),
_ssh.connect(hostname=host, password=password, pkey=None, port=port, username=username, key_filename=()),
_ssh.get_transport(),
_ssh.get_transport().set_keepalive(1),
]
Expand All @@ -131,5 +131,6 @@ def test_init_base(paramiko_ssh_client, auto_add_policy, run_parameters, ssh_aut
if ssh.auth.username:
expected_config_dict[host]["user"] = ssh.auth.username

assert ssh.ssh_config[host] == expected_config_dict[host]
assert ssh.ssh_config == expected_config_dict
assert ssh.ssh_config[host].hostname == host
6 changes: 3 additions & 3 deletions test/test_ssh_client_init_special.py
Expand Up @@ -78,7 +78,7 @@ def test_001_require_key(paramiko_ssh_client, auto_add_policy, ssh_auth_logger):

pkey = private_keys[0]

kwargs_no_key = dict(hostname=host, pkey=None, port=port, username=username, password=None, key_filename=None)
kwargs_no_key = dict(hostname=host, pkey=None, port=port, username=username, password=None, key_filename=())
kwargs_full = {key: kwargs_no_key[key] for key in kwargs_no_key}
kwargs_full["pkey"] = pkey

Expand Down Expand Up @@ -113,7 +113,7 @@ def test_002_use_next_key(paramiko_ssh_client, auto_add_policy, ssh_auth_logger)

ssh_auth_logger.debug.assert_called_once_with(f"Main key has been updated, public key is: \n{ssh.auth.public_key}")

kwargs_no_key = dict(hostname=host, pkey=None, port=port, username=username, password=None, key_filename=None)
kwargs_no_key = dict(hostname=host, pkey=None, port=port, username=username, password=None, key_filename=())
kwargs_key_0 = {key: kwargs_no_key[key] for key in kwargs_no_key}
kwargs_key_0["pkey"] = private_keys[0]
kwargs_key_1 = {key: kwargs_no_key[key] for key in kwargs_no_key}
Expand Down Expand Up @@ -320,7 +320,7 @@ def test_012_re_connect(paramiko_ssh_client, auto_add_policy, ssh_auth_logger):
_ssh.close(),
_ssh,
_ssh.set_missing_host_key_policy("AutoAddPolicy"),
_ssh.connect(hostname="127.0.0.1", password=None, pkey=None, port=22, username=None, key_filename=None),
_ssh.connect(hostname="127.0.0.1", password=None, pkey=None, port=22, username=None, key_filename=()),
_ssh.get_transport(),
_ssh.get_transport().set_keepalive(1),
]
Expand Down
4 changes: 2 additions & 2 deletions test/test_ssh_config.py
Expand Up @@ -103,7 +103,7 @@ def test_no_configs(no_system_ssh_config, no_user_ssh_config):
assert host_config.port is None

assert host_config.user is None
assert host_config.identityfile is None
assert host_config.identityfile == ()

assert host_config.proxycommand is None
assert host_config.proxyjump is None
Expand All @@ -125,7 +125,7 @@ def test_simple_config(system_ssh_config, user_ssh_config):
assert host_config.port == PORT

assert host_config.user == USER
assert host_config.identityfile == IDENTIFY_FILES
assert host_config.identityfile == tuple(IDENTIFY_FILES)

assert host_config.controlpath == f"~/.ssh/.control-{USER}@{HOST}:{PORT}"
assert not host_config.controlmaster # auto => False
Expand Down

0 comments on commit 5bdee86

Please sign in to comment.