Skip to content

Commit

Permalink
Fix ssh stdin processing: do not add newline, support on proxy
Browse files Browse the repository at this point in the history
(cherry picked from commit 0786fc9)
Signed-off-by: Alexey Stepanov <penguinolog@gmail.com>
  • Loading branch information
penguinolog committed Nov 6, 2018
1 parent 9eefbf6 commit a131498
Show file tree
Hide file tree
Showing 7 changed files with 67 additions and 24 deletions.
2 changes: 1 addition & 1 deletion exec_helpers/__init__.py
Expand Up @@ -51,7 +51,7 @@
"ExecResult",
)

__version__ = "1.9.2"
__version__ = "1.9.3"
__author__ = "Alexey Stepanov"
__author_email__ = "penguinolog@gmail.com"
__maintainers__ = {
Expand Down
29 changes: 18 additions & 11 deletions exec_helpers/_ssh_client_base.py
Expand Up @@ -598,14 +598,9 @@ def execute_async(

if stdin is not None:
if not _stdin.channel.closed:
if isinstance(stdin, bytes):
stdin_str = stdin.decode("utf-8")
elif isinstance(stdin, bytearray):
stdin_str = bytes(stdin).decode("utf-8")
else:
stdin_str = stdin

_stdin.write("{stdin}\n".format(stdin=stdin_str).encode("utf-8"))
stdin_str = self._string_bytes_bytearray_as_bytes(stdin)

_stdin.write(stdin_str)
_stdin.flush()
else:
self.logger.warning("STDIN Send failed: closed channel")
Expand Down Expand Up @@ -752,12 +747,23 @@ def execute_through_host(
)

# Make proxy objects for read
stdout = channel.makefile("rb")
stderr = channel.makefile_stderr("rb")
_stdin = channel.makefile("wb") # type: paramiko.ChannelFile
stdout = channel.makefile("rb") # type: paramiko.ChannelFile
stderr = channel.makefile_stderr("rb") # type: paramiko.ChannelFile

channel.exec_command(command) # nosec # Sanitize on caller side

async_result = SshExecuteAsyncResult(interface=channel, stdin=None, stdout=stdout, stderr=stderr)
stdin = kwargs.get("stdin", None)
if stdin is not None:
if not _stdin.channel.closed:
stdin_str = self._string_bytes_bytearray_as_bytes(stdin)

_stdin.write(stdin_str)
_stdin.flush()
else:
self.logger.warning("STDIN Send failed: closed channel")

async_result = SshExecuteAsyncResult(interface=channel, stdin=_stdin, stdout=stdout, stderr=stderr)

# noinspection PyDictCreation
result = self._exec_command(
Expand All @@ -766,6 +772,7 @@ def execute_through_host(
timeout=timeout,
verbose=verbose,
log_mask_re=kwargs.get("log_mask_re", None),
stdin=stdin,
)

intermediate_channel.close()
Expand Down
28 changes: 23 additions & 5 deletions exec_helpers/api.py
Expand Up @@ -63,7 +63,7 @@ def __init__(self, logger, log_mask_re=None): # type: (logging.Logger, typing.O
:type log_mask_re: typing.Optional[str]
.. versionchanged:: 1.2.0 log_mask_re regex rule for masking cmd
.. versionchanged:: 1.3.5 make API public paramikoto use as interface
.. versionchanged:: 1.3.5 make API public to use as interface
"""
self.__lock = threading.RLock()
self.__logger = logger
Expand Down Expand Up @@ -281,18 +281,18 @@ def check_call(
.. versionchanged:: 1.2.0 default timeout 1 hour
"""
expected = proc_enums.exit_codes_to_enums(expected)
expected_codes = proc_enums.exit_codes_to_enums(expected)
ret = self.execute(command, verbose, timeout, **kwargs)
if ret.exit_code not in expected:
if ret.exit_code not in expected_codes:
message = (
"{append}Command {result.cmd!r} returned exit code "
"{result.exit_code!s} while expected {expected!s}".format(
append=error_info + "\n" if error_info else "", result=ret, expected=expected
append=error_info + "\n" if error_info else "", result=ret, expected=expected_codes
)
)
self.logger.error(msg=message)
if raise_on_err:
raise exceptions.CalledProcessError(result=ret, expected=expected)
raise exceptions.CalledProcessError(result=ret, expected=expected_codes)
return ret

def check_stderr(
Expand Down Expand Up @@ -337,3 +337,21 @@ def check_stderr(
if raise_on_err:
raise exceptions.CalledProcessError(result=ret, expected=kwargs.get("expected"))
return ret

@staticmethod
def _string_bytes_bytearray_as_bytes(src): # type: (typing.Union[six.text_type, bytes, bytearray]) -> bytes
"""Get bytes string from string/bytes/bytearray union.
:return: Byte string
:rtype: bytes
:raises TypeError: unexpected source type.
"""
if isinstance(src, bytes):
return src
if isinstance(src, bytearray):
return bytes(src)
if isinstance(src, six.text_type):
return src.encode("utf-8")
raise TypeError( # pragma: no cover
"{!r} has unexpected type: not conform to Union[str, bytes, bytearray]".format(src)
)
2 changes: 1 addition & 1 deletion exec_helpers/exec_result.py
Expand Up @@ -385,7 +385,7 @@ def __deserialize(self, fmt): # type: (str) -> typing.Any
try:
if fmt == "json":
return json.loads(self.stdout_str, encoding="utf-8")
elif fmt == "yaml":
if fmt == "yaml":
return yaml.safe_load(self.stdout_str)
except Exception:
tmpl = " stdout is not valid {fmt}:\n" "{{stdout!r}}\n".format(fmt=fmt)
Expand Down
7 changes: 2 additions & 5 deletions exec_helpers/subprocess_runner.py
Expand Up @@ -225,12 +225,9 @@ def execute_async(
if stdin is None:
process_stdin = process.stdin
else:
if isinstance(stdin, six.text_type):
stdin = stdin.encode(encoding="utf-8")
elif isinstance(stdin, bytearray):
stdin = bytes(stdin)
stdin_str = self._string_bytes_bytearray_as_bytes(stdin)
try:
process.stdin.write(stdin)
process.stdin.write(stdin_str)
except OSError as exc:
if exc.errno == errno.EINVAL:
# bpo-19612, bpo-30418: On Windows, stdin.write() fails
Expand Down
2 changes: 1 addition & 1 deletion test/test_ssh_client_execute.py
Expand Up @@ -342,7 +342,7 @@ def test_001_execute_async(ssh, paramiko_ssh_client, ssh_transport_channel, chan
assert res.stdin.channel == res.interface

if stdin:
res.stdin.write.assert_called_with("{stdin}\n".format(stdin=stdin).encode("utf-8"))
res.stdin.write.assert_called_with(stdin.encode("utf-8"))
res.stdin.flush.assert_called_once()
log = get_logger(ssh.__class__.__name__).getChild("{host}:{port}".format(host=host, port=port))
log.log.assert_called_once_with(level=logging.DEBUG, msg=command_log)
Expand Down
21 changes: 21 additions & 0 deletions test/test_ssh_client_execute_throw_host.py
Expand Up @@ -186,3 +186,24 @@ def test_03_execute_get_pty(ssh, ssh_transport_channel):
target = "127.0.0.2"
ssh.execute_through_host(target, command, get_pty=True)
ssh_transport_channel.get_pty.assert_called_with(term="vt100", width=80, height=24, width_pixels=0, height_pixels=0)


def test_04_execute_use_stdin(ssh, chan_makefile):
target = "127.0.0.2"
cmd = 'read line; echo "$line"'
stdin = "test"
res = ssh.execute_through_host(target, cmd, stdin=stdin, get_pty=True)
assert res.stdin == stdin
chan_makefile.stdin.write.assert_called_once_with(stdin.encode("utf-8"))
chan_makefile.stdin.flush.assert_called_once()


def test_05_execute_closed_stdin(ssh, ssh_transport_channel, get_logger):
target = "127.0.0.2"
cmd = 'read line; echo "$line"'
stdin = "test"
ssh_transport_channel.closed = True

ssh.execute_through_host(target, cmd, stdin=stdin, get_pty=True)
log = get_logger(ssh.__class__.__name__).getChild("{host}:{port}".format(host=host, port=port))
log.warning.assert_called_once_with("STDIN Send failed: closed channel")

0 comments on commit a131498

Please sign in to comment.