From c7b1c34781a28458d0c842ede653c9382e69bf27 Mon Sep 17 00:00:00 2001 From: Alexey Stepanov Date: Tue, 6 Nov 2018 17:04:47 +0100 Subject: [PATCH] Fix ssh stdin processing: do not add newline, support on proxy Signed-off-by: Alexey Stepanov (cherry picked from commit 0786fc95ac13b797da57c6218db394ad5609a826) Signed-off-by: Alexey Stepanov --- exec_helpers/__init__.py | 2 +- exec_helpers/_ssh_client_base.py | 29 ++++++++++++++-------- exec_helpers/api.py | 28 +++++++++++++++++---- exec_helpers/exec_result.py | 4 +-- exec_helpers/subprocess_runner.py | 7 ++---- test/test_ssh_client_execute.py | 2 +- test/test_ssh_client_execute_throw_host.py | 21 ++++++++++++++++ 7 files changed, 68 insertions(+), 25 deletions(-) diff --git a/exec_helpers/__init__.py b/exec_helpers/__init__.py index 9146db1..cfaf6be 100644 --- a/exec_helpers/__init__.py +++ b/exec_helpers/__init__.py @@ -49,7 +49,7 @@ "ExecResult", ) -__version__ = "2.9.1" +__version__ = "2.9.2" __author__ = "Alexey Stepanov" __author_email__ = "penguinolog@gmail.com" __maintainers__ = { diff --git a/exec_helpers/_ssh_client_base.py b/exec_helpers/_ssh_client_base.py index 066b15b..1e83378 100644 --- a/exec_helpers/_ssh_client_base.py +++ b/exec_helpers/_ssh_client_base.py @@ -618,14 +618,9 @@ def execute_async( if stdin is not None: if not _stdin.channel.closed: - if isinstance(stdin, bytes): - stdin_str = stdin.decode("utf-8") - elif isinstance(stdin, bytearray): - stdin_str = bytes(stdin).decode("utf-8") - else: - stdin_str = stdin - - _stdin.write("{stdin}\n".format(stdin=stdin_str).encode("utf-8")) + stdin_str = self._string_bytes_bytearray_as_bytes(stdin) + + _stdin.write(stdin_str) _stdin.flush() else: self.logger.warning("STDIN Send failed: closed channel") @@ -773,12 +768,23 @@ def execute_through_host( ) # Make proxy objects for read - stdout = channel.makefile("rb") - stderr = channel.makefile_stderr("rb") + _stdin = channel.makefile("wb") # type: paramiko.ChannelFile + stdout = channel.makefile("rb") # type: paramiko.ChannelFile + stderr = channel.makefile_stderr("rb") # type: paramiko.ChannelFile channel.exec_command(command) # nosec # Sanitize on caller side - async_result = SshExecuteAsyncResult(interface=channel, stdin=None, stdout=stdout, stderr=stderr) + stdin = kwargs.get("stdin", None) + if stdin is not None: + if not _stdin.channel.closed: + stdin_str = self._string_bytes_bytearray_as_bytes(stdin) + + _stdin.write(stdin_str) + _stdin.flush() + else: + self.logger.warning("STDIN Send failed: closed channel") + + async_result = SshExecuteAsyncResult(interface=channel, stdin=_stdin, stdout=stdout, stderr=stderr) # noinspection PyDictCreation result = self._exec_command( @@ -787,6 +793,7 @@ def execute_through_host( timeout=timeout, verbose=verbose, log_mask_re=kwargs.get("log_mask_re", None), + stdin=stdin, ) intermediate_channel.close() diff --git a/exec_helpers/api.py b/exec_helpers/api.py index d65ff7d..7eab339 100644 --- a/exec_helpers/api.py +++ b/exec_helpers/api.py @@ -59,7 +59,7 @@ def __init__(self, logger: logging.Logger, log_mask_re: typing.Optional[str] = N :type log_mask_re: typing.Optional[str] .. versionchanged:: 1.2.0 log_mask_re regex rule for masking cmd - .. versionchanged:: 1.3.5 make API public paramikoto use as interface + .. versionchanged:: 1.3.5 make API public to use as interface """ self.__lock = threading.RLock() self.__logger = logger @@ -279,18 +279,18 @@ def check_call( .. versionchanged:: 1.2.0 default timeout 1 hour """ - expected = proc_enums.exit_codes_to_enums(expected) + expected_codes = proc_enums.exit_codes_to_enums(expected) ret = self.execute(command, verbose, timeout, **kwargs) - if ret.exit_code not in expected: + if ret.exit_code not in expected_codes: message = ( "{append}Command {result.cmd!r} returned exit code " "{result.exit_code!s} while expected {expected!s}".format( - append=error_info + "\n" if error_info else "", result=ret, expected=expected + append=error_info + "\n" if error_info else "", result=ret, expected=expected_codes ) ) self.logger.error(msg=message) if raise_on_err: - raise exceptions.CalledProcessError(result=ret, expected=expected) + raise exceptions.CalledProcessError(result=ret, expected=expected_codes) return ret def check_stderr( @@ -335,3 +335,21 @@ def check_stderr( if raise_on_err: raise exceptions.CalledProcessError(result=ret, expected=kwargs.get("expected")) return ret + + @staticmethod + def _string_bytes_bytearray_as_bytes(src: typing.Union[str, bytes, bytearray]) -> bytes: + """Get bytes string from string/bytes/bytearray union. + + :return: Byte string + :rtype: bytes + :raises TypeError: unexpected source type. + """ + if isinstance(src, bytes): + return src + if isinstance(src, bytearray): + return bytes(src) + if isinstance(src, str): + return src.encode("utf-8") + raise TypeError( # pragma: no cover + "{!r} has unexpected type: not conform to Union[str, bytes, bytearray]".format(src) + ) diff --git a/exec_helpers/exec_result.py b/exec_helpers/exec_result.py index 70357c5..c4960fe 100644 --- a/exec_helpers/exec_result.py +++ b/exec_helpers/exec_result.py @@ -376,9 +376,9 @@ def __deserialize(self, fmt: str) -> typing.Any: :raises DeserializeValueError: Not valid source format """ try: - if fmt == "json": # pylint: disable=no-else-return + if fmt == "json": return json.loads(self.stdout_str, encoding="utf-8") - elif fmt == "yaml": + if fmt == "yaml": return yaml.safe_load(self.stdout_str) except Exception as e: tmpl = "{{self.cmd}} stdout is not valid {fmt}:\n" "{{stdout!r}}\n".format(fmt=fmt) diff --git a/exec_helpers/subprocess_runner.py b/exec_helpers/subprocess_runner.py index 475dfb4..1b79e80 100644 --- a/exec_helpers/subprocess_runner.py +++ b/exec_helpers/subprocess_runner.py @@ -236,12 +236,9 @@ def execute_async( if stdin is None: process_stdin = process.stdin else: - if isinstance(stdin, str): - stdin = stdin.encode(encoding="utf-8") - elif isinstance(stdin, bytearray): - stdin = bytes(stdin) + stdin_str = self._string_bytes_bytearray_as_bytes(stdin) try: - process.stdin.write(stdin) + process.stdin.write(stdin_str) except OSError as exc: if exc.errno == errno.EINVAL: # bpo-19612, bpo-30418: On Windows, stdin.write() fails diff --git a/test/test_ssh_client_execute.py b/test/test_ssh_client_execute.py index 03248c7..c60c4f6 100644 --- a/test/test_ssh_client_execute.py +++ b/test/test_ssh_client_execute.py @@ -339,7 +339,7 @@ def test_001_execute_async(ssh, paramiko_ssh_client, ssh_transport_channel, chan assert res.stdin.channel == res.interface if stdin: - res.stdin.write.assert_called_with("{stdin}\n".format(stdin=stdin).encode("utf-8")) + res.stdin.write.assert_called_with(stdin.encode("utf-8")) res.stdin.flush.assert_called_once() log = get_logger(ssh.__class__.__name__).getChild("{host}:{port}".format(host=host, port=port)) log.log.assert_called_once_with(level=logging.DEBUG, msg=command_log) diff --git a/test/test_ssh_client_execute_throw_host.py b/test/test_ssh_client_execute_throw_host.py index 62b59cb..3416a13 100644 --- a/test/test_ssh_client_execute_throw_host.py +++ b/test/test_ssh_client_execute_throw_host.py @@ -181,3 +181,24 @@ def test_03_execute_get_pty(ssh, ssh_transport_channel) -> None: target = "127.0.0.2" ssh.execute_through_host(target, command, get_pty=True) ssh_transport_channel.get_pty.assert_called_with(term="vt100", width=80, height=24, width_pixels=0, height_pixels=0) + + +def test_04_execute_use_stdin(ssh, chan_makefile) -> None: + target = "127.0.0.2" + cmd = 'read line; echo "$line"' + stdin = "test" + res = ssh.execute_through_host(target, cmd, stdin=stdin, get_pty=True) + assert res.stdin == stdin + chan_makefile.stdin.write.assert_called_once_with(stdin.encode("utf-8")) + chan_makefile.stdin.flush.assert_called_once() + + +def test_05_execute_closed_stdin(ssh, ssh_transport_channel, get_logger) -> None: + target = "127.0.0.2" + cmd = 'read line; echo "$line"' + stdin = "test" + ssh_transport_channel.closed = True + + ssh.execute_through_host(target, cmd, stdin=stdin, get_pty=True) + log = get_logger(ssh.__class__.__name__).getChild("{host}:{port}".format(host=host, port=port)) + log.warning.assert_called_once_with("STDIN Send failed: closed channel")