Skip to content

Commit

Permalink
Stop ssh connection caching. Fix #132 (#135)
Browse files Browse the repository at this point in the history
* Stop ssh connection caching. Fix #132
  • Loading branch information
penguinolog committed Jul 3, 2019
1 parent 970901e commit b9f469c
Show file tree
Hide file tree
Showing 14 changed files with 40 additions and 294 deletions.
8 changes: 2 additions & 6 deletions README.rst
Expand Up @@ -29,7 +29,7 @@ you can call command with timeout, but without receiving return code,
or call command and wait for return code, but without timeout processing.

In the most cases, we are need just simple SSH client with comfortable API for calls, calls via SSH proxy and checking return code/stderr.
This library offers this functionality with connection memorizing, deadlock free polling and friendly result objects
This library offers this functionality with deadlock free polling and friendly result objects
(with inline decoding of XML Element tree, YAML, JSON, binary or just strings).
In addition this library offers the same API for subprocess calls, but with specific limitation: no parallel calls
(for protection from race conditions).
Expand All @@ -38,7 +38,6 @@ Pros:

* STDOUT and STDERR polling during command execution - no deadlocks.
* The same API for subprocess and ssh.
* Connection memorize.
* Free software: Apache license
* Open Source: https://github.com/python-useful-helpers/exec-helpers
* PyPI packaged: https://pypi.python.org/pypi/exec-helpers
Expand All @@ -55,8 +54,7 @@ Pros:

This package includes:

* `SSHClient` - historically the first one helper, which used for SSH connections and requires memorization
due to impossibility of connection close prediction.
* `SSHClient` - historically the first one helper, which used for SSH connections.
Several API calls for sFTP also presents.

* `SSHAuth` - class for credentials storage. `SSHClient` does not store credentials as-is, but uses `SSHAuth` for it.
Expand Down Expand Up @@ -125,8 +123,6 @@ Passphrase is an alternate password for keys, if it differs from main password.
If main key now correct for username - alternate keys tried, if correct key found - it became main.
If no working key - password is used and None is set as main key.

.. note:: Automatic closing connections during cache record removal supported on CPython implementation only.

Context manager is available, connection is closed and lock is released on exit from context.

.. note:: context manager is strictly not recommended in scenarios with fast reconnect to the same host with te same credentials.
Expand Down
25 changes: 7 additions & 18 deletions doc/source/SSHClient.rst
Expand Up @@ -80,10 +80,6 @@ API: SSHClient and SSHAuth.
Close connection

.. py:classmethod:: close()
Close all memorized connections

.. py:method:: reconnect()
Reconnect SSH session
Expand Down Expand Up @@ -147,8 +143,7 @@ API: SSHClient and SSHAuth.
:type open_stderr: bool
:param verbose: produce verbose log record on command call
:type verbose: bool
:param log_mask_re: regex lookup rule to mask command for logger.
all MATCHED groups will be replaced by '<*masked*>'
:param log_mask_re: regex lookup rule to mask command for logger. all MATCHED groups will be replaced by '<*masked*>'
:type log_mask_re: ``typing.Optional[str]``
:param chroot_path: chroot path override
:type chroot_path: ``typing.Optional[str]``
Expand Down Expand Up @@ -176,8 +171,7 @@ API: SSHClient and SSHAuth.
:type verbose: ``bool``
:param timeout: Timeout for command execution.
:type timeout: ``typing.Union[int, float, None]``
:param log_mask_re: regex lookup rule to mask command for logger.
all MATCHED groups will be replaced by '<*masked*>'
:param log_mask_re: regex lookup rule to mask command for logger. all MATCHED groups will be replaced by '<*masked*>'
:type log_mask_re: ``typing.Optional[str]``
:param stdin: pass STDIN text to the process
:type stdin: ``typing.Union[bytes, str, bytearray, None]``
Expand All @@ -196,8 +190,7 @@ API: SSHClient and SSHAuth.
:type verbose: ``bool``
:param timeout: Timeout for command execution.
:type timeout: ``typing.Union[int, float, None]``
:param log_mask_re: regex lookup rule to mask command for logger.
all MATCHED groups will be replaced by '<*masked*>'
:param log_mask_re: regex lookup rule to mask command for logger. all MATCHED groups will be replaced by '<*masked*>'
:type log_mask_re: ``typing.Optional[str]``
:param stdin: pass STDIN text to the process
:type stdin: ``typing.Union[bytes, str, bytearray, None]``
Expand All @@ -222,8 +215,7 @@ API: SSHClient and SSHAuth.
:type expected: typing.Iterable[typing.Union[int, ExitCodes]]
:param raise_on_err: Raise exception on unexpected return code
:type raise_on_err: ``bool``
:param log_mask_re: regex lookup rule to mask command for logger.
all MATCHED groups will be replaced by '<*masked*>'
:param log_mask_re: regex lookup rule to mask command for logger. all MATCHED groups will be replaced by '<*masked*>'
:type log_mask_re: ``typing.Optional[str]``
:param stdin: pass STDIN text to the process
:type stdin: ``typing.Union[bytes, str, bytearray, None]``
Expand Down Expand Up @@ -253,8 +245,7 @@ API: SSHClient and SSHAuth.
:type raise_on_err: ``bool``
:param expected: expected return codes (0 by default)
:type expected: typing.Iterable[typing.Union[int, ExitCodes]]
:param log_mask_re: regex lookup rule to mask command for logger.
all MATCHED groups will be replaced by '<*masked*>'
:param log_mask_re: regex lookup rule to mask command for logger. all MATCHED groups will be replaced by '<*masked*>'
:type log_mask_re: ``typing.Optional[str]``
:param stdin: pass STDIN text to the process
:type stdin: ``typing.Union[bytes, str, bytearray, None]``
Expand Down Expand Up @@ -285,8 +276,7 @@ API: SSHClient and SSHAuth.
:type timeout: ``typing.Union[int, float, None]``
:param stdin: pass STDIN text to the process
:type stdin: typing.Union[bytes, str, bytearray, None]
:param log_mask_re: regex lookup rule to mask command for logger.
all MATCHED groups will be replaced by '<*masked*>'
:param log_mask_re: regex lookup rule to mask command for logger. all MATCHED groups will be replaced by '<*masked*>'
:type log_mask_re: typing.Optional[str]
:param get_pty: open PTY on target machine
:type get_pty: ``bool``
Expand Down Expand Up @@ -318,8 +308,7 @@ API: SSHClient and SSHAuth.
:type raise_on_err: ``bool``
:param stdin: pass STDIN text to the process
:type stdin: typing.Union[bytes, str, bytearray, None]
:param log_mask_re: regex lookup rule to mask command for logger.
all MATCHED groups will be replaced by '<*masked*>'
:param log_mask_re: regex lookup rule to mask command for logger. all MATCHED groups will be replaced by '<*masked*>'
:type log_mask_re: typing.Optional[str]
:param exception_class: Exception to raise on error. Mandatory subclass of ParallelCallProcessError
:type exception_class: typing.Type[ParallelCallProcessError]
Expand Down
146 changes: 5 additions & 141 deletions exec_helpers/_ssh_client_base.py
Expand Up @@ -19,21 +19,16 @@
__all__ = ("SSHClientBase", "SshExecuteAsyncResult")

# Standard Library
import abc
import base64
import concurrent.futures
import copy
import datetime
import logging
import platform
import stat
import sys
import time
import typing
import warnings

# External Dependencies
import advanced_descriptors
import paramiko # type: ignore
import tenacity # type: ignore
import threaded
Expand Down Expand Up @@ -92,123 +87,6 @@ def stdout(self) -> typing.Optional[paramiko.ChannelFile]: # type: ignore
return super(SshExecuteAsyncResult, self).stdout


CPYTHON = "CPython" == platform.python_implementation()


class _MemorizedSSH(abc.ABCMeta):
"""Memorize metaclass for SSHClient.
This class implements caching and managing of SSHClient connections.
Class is not in public scope: all required interfaces is accessible throw
SSHClient classmethods.
Main flow is:
SSHClient() -> check for cached connection and
- If exists the same: check for alive, reconnect if required and return
- If exists with different credentials: delete and continue processing
create new connection and cache on success
* Note: each invocation of SSHClient instance will return current dir to
the root of the current user home dir ("cd ~").
It is necessary to avoid unpredictable behavior when the same
connection is used from different places.
If you need to enter some directory and execute command there, please
use the following approach:
cmd1 = "cd <some dir> && <command1>"
cmd2 = "cd <some dir> && <command2>"
Close cached connections is allowed per-client and all stored:
connection will be closed, but still stored in cache for faster reconnect
Clear cache is strictly not recommended:
from this moment all open connections should be managed manually,
duplicates is possible.
"""

__cache: typing.Dict[typing.Tuple[str, int], "SSHClientBase"] = {}

def __call__( # type: ignore
cls: "_MemorizedSSH",
host: str,
port: int = 22,
username: typing.Optional[str] = None,
password: typing.Optional[str] = None,
private_keys: typing.Optional[typing.Iterable[paramiko.RSAKey]] = None,
auth: typing.Optional[ssh_auth.SSHAuth] = None,
verbose: bool = True,
) -> "SSHClientBase":
"""Main memorize method: check for cached instance and return it. API follows target __init__.
:param host: remote hostname
:type host: str
:param port: remote ssh port
:type port: int
:param username: remote username.
:type username: typing.Optional[str]
:param password: remote password
:type password: typing.Optional[str]
:param private_keys: private keys for connection
:type private_keys: typing.Optional[typing.Iterable[paramiko.RSAKey]]
:param auth: credentials for connection
:type auth: typing.Optional[ssh_auth.SSHAuth]
:param verbose: show additional error/warning messages
:type verbose: bool
:return: SSH client instance
:rtype: SSHClientBase
"""
if (host, port) in cls.__cache:
key = host, port
if auth is None:
auth = ssh_auth.SSHAuth(username=username, password=password, keys=private_keys)
if hash((cls, host, port, auth)) == hash(cls.__cache[key]):
ssh: "SSHClientBase" = cls.__cache[key]
# noinspection PyBroadException
try:
ssh.execute("cd ~", timeout=5)
except BaseException: # Note: Do not change to lower level!
ssh.logger.debug("Reconnect")
ssh.reconnect()
return ssh
if CPYTHON and sys.getrefcount(cls.__cache[key]) == 2: # pragma: no cover
# If we have only cache reference and temporary getrefcount
# reference: close connection before deletion
cls.__cache[key].logger.debug("Closing as unused")
cls.__cache[key].close() # type: ignore
del cls.__cache[key]
# noinspection PyArgumentList
ssh = super(_MemorizedSSH, cls).__call__(
host=host,
port=port,
username=username,
password=password,
private_keys=private_keys,
auth=auth,
verbose=verbose,
)
cls.__cache[(ssh.hostname, ssh.port)] = ssh
return ssh

@classmethod
def clear_cache(mcs: typing.Type["_MemorizedSSH"]) -> None:
"""Clear cached connections for initialize new instance on next call.
getrefcount is used to check for usage, so connections closed on CPYTHON only.
"""
n_count = 3
# PY3: cache, ssh, temporary
for ssh in mcs.__cache.values():
if CPYTHON and sys.getrefcount(ssh) == n_count: # pragma: no cover
ssh.logger.debug("Closing as unused")
ssh.close() # type: ignore
mcs.__cache = {}

@classmethod
def close_connections(mcs: typing.Type["_MemorizedSSH"]) -> None:
"""Close connections for selected or all cached records."""
for ssh in mcs.__cache.values():
if ssh.is_alive:
ssh.close() # type: ignore


class _SudoContext:
"""Context manager for call commands with sudo."""

Expand Down Expand Up @@ -265,7 +143,7 @@ def __exit__(self, exc_type: typing.Any, exc_val: typing.Any, exc_tb: typing.Any
self.__ssh.keepalive_mode = self.__keepalive_status


class SSHClientBase(api.ExecHelper, metaclass=_MemorizedSSH):
class SSHClientBase(api.ExecHelper):
"""SSH Client helper."""

__slots__ = ("__hostname", "__port", "__auth", "__ssh", "__sftp", "__sudo_mode", "__keepalive_mode", "__verbose")
Expand Down Expand Up @@ -415,7 +293,6 @@ def _sftp(self) -> paramiko.sftp_client.SFTPClient:
return self.__sftp
raise paramiko.SSHException("SFTP connection failed")

@advanced_descriptors.SeparateClassMethod
def close(self) -> None:
"""Close SSH and SFTP sessions."""
with self.lock:
Expand All @@ -432,19 +309,6 @@ def close(self) -> None:
except Exception:
self.logger.exception("Could not close sftp connection")

# noinspection PyMethodParameters
@close.class_method # type: ignore
def close(cls: typing.Type["SSHClientBase"]) -> None: # pylint: disable=no-self-argument
"""Close all memorized SSH and SFTP sessions."""
# noinspection PyUnresolvedReferences
cls.__class__.close_connections()

@classmethod
def _clear_cache(cls: typing.Type["SSHClientBase"]) -> None:
"""Enforce clear memorized records."""
warnings.warn("_clear_cache() is dangerous and not recommended for normal use!", Warning)
_MemorizedSSH.clear_cache()

def __del__(self) -> None:
"""Destructor helper: close channel and threads BEFORE closing others.
Expand All @@ -465,7 +329,7 @@ def __exit__(self, exc_type: typing.Any, exc_val: typing.Any, exc_tb: typing.Any
.. versionchanged:: 1.2.1 disconnect enforced on close only not in keepalive mode
"""
if not self.__keepalive_mode:
self.close() # type: ignore
self.close()
super(SSHClientBase, self).__exit__(exc_type, exc_val, exc_tb)

@property
Expand Down Expand Up @@ -505,7 +369,7 @@ def keepalive_mode(self, mode: bool) -> None:
def reconnect(self) -> None:
"""Reconnect SSH session."""
with self.lock:
self.close() # type: ignore
self.close()

self.__ssh = paramiko.SSHClient()
self.__ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
Expand Down Expand Up @@ -684,7 +548,7 @@ def poll_streams() -> None:
if async_result.stderr and async_result.interface.recv_stderr_ready():
result.read_stderr(src=async_result.stderr, log=self.logger, verbose=verbose)

@threaded.threadpooled
@threaded.threadpooled # type: ignore
def poll_pipes() -> None:
"""Polling task for FIFO buffers."""
while not async_result.interface.status_event.is_set():
Expand Down Expand Up @@ -872,7 +736,7 @@ def execute_together(
.. versionchanged:: 4.0.0 Expose stdin and log_mask_re as optional keyword-only arguments
"""

@threaded.threadpooled
@threaded.threadpooled # type: ignore
def get_result(remote: "SSHClientBase") -> exec_result.ExecResult:
"""Get result from remote call.
Expand Down
17 changes: 17 additions & 0 deletions exec_helpers/api.py
Expand Up @@ -520,6 +520,23 @@ def check_stderr(
stdin=stdin,
**kwargs,
)
return self._handle_stderr(
result=result,
error_info=error_info,
raise_on_err=raise_on_err,
expected=expected,
exception_class=exception_class,
)

def _handle_stderr(
self,
result: exec_result.ExecResult,
error_info: typing.Optional[str],
raise_on_err: bool,
expected: typing.Iterable[typing.Union[int, proc_enums.ExitCodes]],
exception_class: "typing.Type[exceptions.CalledProcessError]",
) -> exec_result.ExecResult:
"""Internal check_stderr logic (synchronous)."""
append: str = error_info + "\n" if error_info else ""
if result.stderr:
message = (
Expand Down
17 changes: 7 additions & 10 deletions exec_helpers/async_api/api.py
Expand Up @@ -379,13 +379,10 @@ async def check_stderr( # type: ignore
stdin=stdin,
**kwargs,
)
append: str = error_info + "\n" if error_info else ""
if result.stderr:
message = (
f"{append}Command {result.cmd!r} output contains STDERR while not expected\n"
f"\texit code: {result.exit_code!s}"
)
self.logger.error(msg=message)
if raise_on_err:
raise exception_class(result=result, expected=expected)
return result
return self._handle_stderr(
result=result,
error_info=error_info,
raise_on_err=raise_on_err,
expected=expected,
exception_class=exception_class,
)
4 changes: 2 additions & 2 deletions exec_helpers/subprocess_runner.py
Expand Up @@ -126,12 +126,12 @@ def _exec_command( # type: ignore
.. versionadded:: 1.2.0
"""

@threaded.threadpooled
@threaded.threadpooled # type: ignore
def poll_stdout() -> None:
"""Sync stdout poll."""
result.read_stdout(src=async_result.stdout, log=self.logger, verbose=verbose)

@threaded.threadpooled
@threaded.threadpooled # type: ignore
def poll_stderr() -> None:
"""Sync stderr poll."""
result.read_stderr(src=async_result.stderr, log=self.logger, verbose=verbose)
Expand Down

0 comments on commit b9f469c

Please sign in to comment.