diff --git a/exec_helpers/__init__.py b/exec_helpers/__init__.py index 424e9bd..0eccbbe 100644 --- a/exec_helpers/__init__.py +++ b/exec_helpers/__init__.py @@ -51,7 +51,7 @@ "async_api", ) -__version__ = "3.1.1" +__version__ = "3.1.2" __author__ = "Alexey Stepanov" __author_email__ = "penguinolog@gmail.com" __maintainers__ = { diff --git a/exec_helpers/_ssh_client_base.py b/exec_helpers/_ssh_client_base.py index e2f52c5..066b15b 100644 --- a/exec_helpers/_ssh_client_base.py +++ b/exec_helpers/_ssh_client_base.py @@ -27,7 +27,6 @@ import platform import stat import sys -import threading import time import typing import warnings @@ -619,7 +618,14 @@ def execute_async( if stdin is not None: if not _stdin.channel.closed: - _stdin.write("{stdin}\n".format(stdin=stdin)) + if isinstance(stdin, bytes): + stdin_str = stdin.decode("utf-8") + elif isinstance(stdin, bytearray): + stdin_str = bytes(stdin).decode("utf-8") + else: + stdin_str = stdin + + _stdin.write("{stdin}\n".format(stdin=stdin_str).encode("utf-8")) _stdin.flush() else: self.logger.warning("STDIN Send failed: closed channel") @@ -665,22 +671,16 @@ def poll_streams() -> None: result.read_stderr(src=async_result.stderr, log=self.logger, verbose=verbose) @threaded.threadpooled - def poll_pipes(stop: threading.Event) -> None: - """Polling task for FIFO buffers. - - :type stop: Event - """ - while not stop.is_set(): + def poll_pipes() -> None: + """Polling task for FIFO buffers.""" + while not async_result.interface.status_event.is_set(): time.sleep(0.1) if async_result.stdout or async_result.stderr: poll_streams() - if async_result.interface.status_event.is_set(): - result.read_stdout(src=async_result.stdout, log=self.logger, verbose=verbose) - result.read_stderr(src=async_result.stderr, log=self.logger, verbose=verbose) - result.exit_code = async_result.interface.exit_status - - stop.set() + result.read_stdout(src=async_result.stdout, log=self.logger, verbose=verbose) + result.read_stderr(src=async_result.stderr, log=self.logger, verbose=verbose) + result.exit_code = async_result.interface.exit_status # channel.status_event.wait(timeout) cmd_for_log = self._mask_command(cmd=command, log_mask_re=log_mask_re) @@ -688,22 +688,20 @@ def poll_pipes(stop: threading.Event) -> None: # Store command with hidden data result = exec_result.ExecResult(cmd=cmd_for_log, stdin=kwargs.get("stdin")) - stop_event = threading.Event() - # pylint: disable=assignment-from-no-return # noinspection PyNoneFunctionAssignment - future = poll_pipes(stop=stop_event) # type: concurrent.futures.Future + future = poll_pipes() # type: concurrent.futures.Future # pylint: enable=assignment-from-no-return concurrent.futures.wait([future], timeout) # Process closed? - if stop_event.is_set(): + if async_result.interface.status_event.is_set(): async_result.interface.close() return result - stop_event.set() async_result.interface.close() + async_result.interface.status_event.set() future.cancel() wait_err_msg = _log_templates.CMD_WAIT_ERROR.format(result=result, timeout=timeout) @@ -840,7 +838,7 @@ def get_result(remote: "SSHClientBase") -> exec_result.ExecResult: cmd_for_log = remote._mask_command(cmd=command, log_mask_re=kwargs.get("log_mask_re", None)) # pylint: enable=protected-access - res = exec_result.ExecResult(cmd=cmd_for_log) + res = exec_result.ExecResult(cmd=cmd_for_log, stdin=kwargs.get("stdin", None)) res.read_stdout(src=async_result.stdout) res.read_stderr(src=async_result.stderr) res.exit_code = exit_code diff --git a/exec_helpers/api.py b/exec_helpers/api.py index 7ec86a6..d65ff7d 100644 --- a/exec_helpers/api.py +++ b/exec_helpers/api.py @@ -86,7 +86,7 @@ def __enter__(self) -> "ExecHelper": self.lock.acquire() return self - def __exit__(self, exc_type: typing.Any, exc_val: typing.Any, exc_tb: typing.Any) -> None: # pragma: no cover + def __exit__(self, exc_type: typing.Any, exc_val: typing.Any, exc_tb: typing.Any) -> None: """Context manager usage.""" self.lock.release() diff --git a/test/test_sftp.py b/test/test_sftp.py new file mode 100644 index 0000000..0d3f32f --- /dev/null +++ b/test/test_sftp.py @@ -0,0 +1,287 @@ +# Copyright 2018 Alexey Stepanov aka penguinolog. + +# Copyright 2016 Mirantis, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +# pylint: disable=no-self-use + +import os +import posixpath +import stat +import unittest + +import mock +import paramiko + +import exec_helpers + + +host = "127.0.0.1" +port = 22 +username = "user" +password = "pass" + + +@mock.patch("logging.getLogger", autospec=True) +@mock.patch("paramiko.AutoAddPolicy", autospec=True, return_value="AutoAddPolicy") +@mock.patch("paramiko.SSHClient", autospec=True) +class TestSftp(unittest.TestCase): + def tearDown(self): + with mock.patch("warnings.warn"): + exec_helpers.SSHClient._clear_cache() + + @staticmethod + def prepare_sftp_file_tests(client): + _ssh = mock.Mock() + client.return_value = _ssh + _sftp = mock.Mock() + open_sftp = mock.Mock(parent=_ssh, return_value=_sftp) + _ssh.attach_mock(open_sftp, "open_sftp") + + # noinspection PyTypeChecker + ssh = exec_helpers.SSHClient( + host=host, port=port, auth=exec_helpers.SSHAuth(username=username, password=password) + ) + return ssh, _sftp + + def test_exists(self, client, *args): + ssh, _sftp = self.prepare_sftp_file_tests(client) + lstat = mock.Mock() + _sftp.attach_mock(lstat, "lstat") + dst = "/etc" + + # noinspection PyTypeChecker + result = ssh.exists(dst) + self.assertTrue(result) + lstat.assert_called_once_with(dst) + + # Negative scenario + lstat.reset_mock() + lstat.side_effect = IOError + + # noinspection PyTypeChecker + result = ssh.exists(dst) + self.assertFalse(result) + lstat.assert_called_once_with(dst) + + def test_stat(self, client, *args): + ssh, _sftp = self.prepare_sftp_file_tests(client) + stat = mock.Mock() + _sftp.attach_mock(stat, "stat") + stat.return_value = paramiko.sftp_attr.SFTPAttributes() + stat.return_value.st_size = 0 + stat.return_value.st_uid = 0 + stat.return_value.st_gid = 0 + dst = "/etc/passwd" + + # noinspection PyTypeChecker + result = ssh.stat(dst) + self.assertEqual(result.st_size, 0) + self.assertEqual(result.st_uid, 0) + self.assertEqual(result.st_gid, 0) + + def test_isfile(self, client, *args): + class Attrs: + def __init__(self, mode): + self.st_mode = mode + + ssh, _sftp = self.prepare_sftp_file_tests(client) + lstat = mock.Mock() + _sftp.attach_mock(lstat, "lstat") + lstat.return_value = Attrs(stat.S_IFREG) + dst = "/etc/passwd" + + # noinspection PyTypeChecker + result = ssh.isfile(dst) + self.assertTrue(result) + lstat.assert_called_once_with(dst) + + # Negative scenario + lstat.reset_mock() + lstat.return_value = Attrs(stat.S_IFDIR) + + # noinspection PyTypeChecker + result = ssh.isfile(dst) + self.assertFalse(result) + lstat.assert_called_once_with(dst) + + lstat.reset_mock() + lstat.side_effect = IOError + + # noinspection PyTypeChecker + result = ssh.isfile(dst) + self.assertFalse(result) + lstat.assert_called_once_with(dst) + + def test_isdir(self, client, *args): + class Attrs: + def __init__(self, mode): + self.st_mode = mode + + ssh, _sftp = self.prepare_sftp_file_tests(client) + lstat = mock.Mock() + _sftp.attach_mock(lstat, "lstat") + lstat.return_value = Attrs(stat.S_IFDIR) + dst = "/etc/passwd" + + # noinspection PyTypeChecker + result = ssh.isdir(dst) + self.assertTrue(result) + lstat.assert_called_once_with(dst) + + # Negative scenario + lstat.reset_mock() + lstat.return_value = Attrs(stat.S_IFREG) + + # noinspection PyTypeChecker + result = ssh.isdir(dst) + self.assertFalse(result) + lstat.assert_called_once_with(dst) + + lstat.reset_mock() + lstat.side_effect = IOError + # noinspection PyTypeChecker + result = ssh.isdir(dst) + self.assertFalse(result) + lstat.assert_called_once_with(dst) + + @mock.patch("exec_helpers.ssh_client.SSHClient.exists") + @mock.patch("exec_helpers.ssh_client.SSHClient.execute") + def test_mkdir(self, execute, exists, *args): + exists.side_effect = [False, True] + + dst = "~/tst dir" + escaped_dst = r"~/tst\ dir" + + # noinspection PyTypeChecker + ssh = exec_helpers.SSHClient( + host=host, port=port, auth=exec_helpers.SSHAuth(username=username, password=password) + ) + + # Path not exists + # noinspection PyTypeChecker + ssh.mkdir(dst) + exists.assert_called_once_with(dst) + execute.assert_called_once_with("mkdir -p {}\n".format(escaped_dst)) + + # Path exists + exists.reset_mock() + execute.reset_mock() + + # noinspection PyTypeChecker + ssh.mkdir(dst) + exists.assert_called_once_with(dst) + execute.assert_not_called() + + @mock.patch("exec_helpers.ssh_client.SSHClient.execute") + def test_rm_rf(self, execute, *args): + dst = "~/tst" + + # noinspection PyTypeChecker + ssh = exec_helpers.SSHClient( + host=host, port=port, auth=exec_helpers.SSHAuth(username=username, password=password) + ) + + # Path not exists + # noinspection PyTypeChecker + ssh.rm_rf(dst) + execute.assert_called_once_with("rm -rf {}".format(dst)) + + def test_open(self, client, *args): + ssh, _sftp = self.prepare_sftp_file_tests(client) + fopen = mock.Mock(return_value=True) + _sftp.attach_mock(fopen, "open") + + dst = "/etc/passwd" + mode = "r" + # noinspection PyTypeChecker + result = ssh.open(dst) + fopen.assert_called_once_with(dst, mode) + self.assertTrue(result) + + @mock.patch("exec_helpers.ssh_client.logger", autospec=True) + @mock.patch("exec_helpers.ssh_client.SSHClient.exists") + @mock.patch("os.path.exists", autospec=True) + @mock.patch("exec_helpers.ssh_client.SSHClient.isdir") + @mock.patch("os.path.isdir", autospec=True) + def test_download(self, isdir, remote_isdir, exists, remote_exists, logger, client, policy, _logger): + ssh, _sftp = self.prepare_sftp_file_tests(client) + isdir.return_value = True + exists.side_effect = [True, False, False] + remote_isdir.side_effect = [False, False, True] + remote_exists.side_effect = [True, False, False] + + dst = "/etc/environment" + target = "/tmp/environment" + # noinspection PyTypeChecker + result = ssh.download(destination=dst, target=target) + self.assertTrue(result) + isdir.assert_called_once_with(target) + exists.assert_called_once_with(posixpath.join(target, os.path.basename(dst))) + remote_isdir.assert_called_once_with(dst) + remote_exists.assert_called_once_with(dst) + _sftp.assert_has_calls((mock.call.get(dst, posixpath.join(target, os.path.basename(dst))),)) + + # Negative scenarios + # noinspection PyTypeChecker + result = ssh.download(destination=dst, target=target) + self.assertFalse(result) + + # noinspection PyTypeChecker + ssh.download(destination=dst, target=target) + + @mock.patch("exec_helpers.ssh_client.SSHClient.isdir") + @mock.patch("os.path.isdir", autospec=True) + def test_upload_file(self, isdir, remote_isdir, client, *args): + ssh, _sftp = self.prepare_sftp_file_tests(client) + isdir.return_value = False + remote_isdir.return_value = False + target = "/etc/environment" + source = "/tmp/environment" + + # noinspection PyTypeChecker + ssh.upload(source=source, target=target) + isdir.assert_called_once_with(source) + remote_isdir.assert_called_once_with(target) + _sftp.assert_has_calls((mock.call.put(source, target),)) + + @mock.patch("exec_helpers.ssh_client.SSHClient.exists") + @mock.patch("exec_helpers.ssh_client.SSHClient.mkdir") + @mock.patch("os.walk") + @mock.patch("exec_helpers.ssh_client.SSHClient.isdir") + @mock.patch("os.path.isdir", autospec=True) + def test_upload_dir(self, isdir, remote_isdir, walk, mkdir, exists, client, *args): + ssh, _sftp = self.prepare_sftp_file_tests(client) + isdir.return_value = True + remote_isdir.return_value = True + exists.return_value = True + target = "/etc" + source = "/tmp/bash" + filename = "bashrc" + walk.return_value = ((source, "", [filename]),) + expected_path = posixpath.join(target, os.path.basename(source)) + expected_file = posixpath.join(expected_path, filename) + + # noinspection PyTypeChecker + ssh.upload(source=source, target=target) + isdir.assert_called_once_with(source) + remote_isdir.assert_called_once_with(target) + mkdir.assert_called_once_with(expected_path) + exists.assert_called_once_with(expected_file) + _sftp.assert_has_calls( + ( + mock.call.unlink(expected_file), + mock.call.put(os.path.normpath(os.path.join(source, filename)), expected_file), + ) + ) diff --git a/test/test_ssh_client.py b/test/test_ssh_client.py deleted file mode 100644 index d0514df..0000000 --- a/test/test_ssh_client.py +++ /dev/null @@ -1,1479 +0,0 @@ -# Copyright 2018 Alexey Stepanov aka penguinolog. - -# Copyright 2016 Mirantis, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. You may obtain -# a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. - -# pylint: disable=no-self-use - -import base64 -import logging -import os -import posixpath -import stat -import unittest - -import mock -import paramiko - -import exec_helpers -from exec_helpers import constants -from exec_helpers import exec_result - - -class FakeStream: - def __init__(self, *args): - self.__src = list(args) - - def __iter__(self): - if len(self.__src) == 0: - raise IOError() - for _ in range(len(self.__src)): - yield self.__src.pop(0) - - -host = "127.0.0.1" -port = 22 -username = "user" -password = "pass" -command = "ls ~\nline 2\nline 3\nline с кирилицей" -command_log = "Executing command:\n{!r}\n".format(command.rstrip()) -stdout_list = [b" \n", b"2\n", b"3\n", b" \n"] -stdout_str = b"".join(stdout_list).strip().decode("utf-8") -stderr_list = [b" \n", b"0\n", b"1\n", b" \n"] -stderr_str = b"".join(stderr_list).strip().decode("utf-8") -encoded_cmd = base64.b64encode("{}\n".format(command).encode("utf-8")).decode("utf-8") -print_stdin = 'read line; echo "$line"' - - -@mock.patch("logging.getLogger", autospec=True) -@mock.patch("paramiko.AutoAddPolicy", autospec=True, return_value="AutoAddPolicy") -@mock.patch("paramiko.SSHClient", autospec=True) -class TestExecute(unittest.TestCase): - def tearDown(self): - with mock.patch("warnings.warn"): - exec_helpers.SSHClient._clear_cache() - - @staticmethod - def get_ssh(): - """SSHClient object builder for execution tests - - :rtype: exec_wrappers.SSHClient - """ - # noinspection PyTypeChecker - return exec_helpers.SSHClient( - host=host, port=port, auth=exec_helpers.SSHAuth(username=username, password=password) - ) - - @staticmethod - def gen_cmd_result_log_message(result): - return "Command {result.cmd!r} exit code: {result.exit_code!s}".format(result=result) - - def test_001_execute_async(self, client, policy, logger): - chan = mock.Mock() - open_session = mock.Mock(return_value=chan) - transport = mock.Mock() - transport.attach_mock(open_session, "open_session") - get_transport = mock.Mock(return_value=transport) - _ssh = mock.Mock() - _ssh.attach_mock(get_transport, "get_transport") - client.return_value = _ssh - - ssh = self.get_ssh() - - # noinspection PyTypeChecker - result = ssh.execute_async(command=command) - get_transport.assert_called_once() - open_session.assert_called_once() - - self.assertIn(chan, result) - chan.assert_has_calls( - ( - mock.call.makefile("wb"), - mock.call.makefile("rb"), - mock.call.makefile_stderr("rb"), - mock.call.exec_command("{}\n".format(command)), - ) - ) - # raise ValueError(logger.mock_calls) - log = logger(ssh.__class__.__name__).getChild("{host}:{port}".format(host=host, port=port)) - self.assertIn(mock.call.log(level=logging.DEBUG, msg=command_log), log.mock_calls) - - def test_002_execute_async_pty(self, client, policy, logger): - chan = mock.Mock() - open_session = mock.Mock(return_value=chan) - transport = mock.Mock() - transport.attach_mock(open_session, "open_session") - get_transport = mock.Mock(return_value=transport) - _ssh = mock.Mock() - _ssh.attach_mock(get_transport, "get_transport") - client.return_value = _ssh - - ssh = self.get_ssh() - - # noinspection PyTypeChecker - result = ssh.execute_async(command=command, get_pty=True) - get_transport.assert_called_once() - open_session.assert_called_once() - - self.assertIn(chan, result) - chan.assert_has_calls( - ( - mock.call.get_pty(term="vt100", width=80, height=24, width_pixels=0, height_pixels=0), - mock.call.makefile("wb"), - mock.call.makefile("rb"), - mock.call.makefile_stderr("rb"), - mock.call.exec_command("{}\n".format(command)), - ) - ) - log = logger(ssh.__class__.__name__).getChild("{host}:{port}".format(host=host, port=port)) - self.assertIn(mock.call.log(level=logging.DEBUG, msg=command_log), log.mock_calls) - - def test_003_execute_async_no_stdout_stderr(self, client, policy, logger): - chan = mock.Mock() - open_session = mock.Mock(return_value=chan) - transport = mock.Mock() - transport.attach_mock(open_session, "open_session") - get_transport = mock.Mock(return_value=transport) - _ssh = mock.Mock() - _ssh.attach_mock(get_transport, "get_transport") - client.return_value = _ssh - - ssh = self.get_ssh() - - # noinspection PyTypeChecker - result = ssh.execute_async(command=command, open_stdout=False) - - self.assertIn(chan, result) - chan.assert_has_calls( - (mock.call.makefile("wb"), mock.call.makefile_stderr("rb"), mock.call.exec_command("{}\n".format(command))) - ) - - chan.reset_mock() - result = ssh.execute_async(command=command, open_stderr=False) - - self.assertIn(chan, result) - chan.assert_has_calls( - (mock.call.makefile("wb"), mock.call.makefile("rb"), mock.call.exec_command("{}\n".format(command))) - ) - - chan.reset_mock() - result = ssh.execute_async(command=command, open_stdout=False, open_stderr=False) - - self.assertIn(chan, result) - chan.assert_has_calls((mock.call.makefile("wb"), mock.call.exec_command("{}\n".format(command)))) - - def test_004_execute_async_sudo(self, client, policy, logger): - chan = mock.Mock() - open_session = mock.Mock(return_value=chan) - transport = mock.Mock() - transport.attach_mock(open_session, "open_session") - get_transport = mock.Mock(return_value=transport) - _ssh = mock.Mock() - _ssh.attach_mock(get_transport, "get_transport") - client.return_value = _ssh - - ssh = self.get_ssh() - ssh.sudo_mode = True - - # noinspection PyTypeChecker - result = ssh.execute_async(command=command) - get_transport.assert_called_once() - open_session.assert_called_once() - - self.assertIn(chan, result) - chan.assert_has_calls( - ( - mock.call.makefile("wb"), - mock.call.makefile("rb"), - mock.call.makefile_stderr("rb"), - mock.call.exec_command("sudo -S bash -c '" 'eval "$(base64 -d <(echo "{0}"))"\''.format(encoded_cmd)), - ) - ) - log = logger(ssh.__class__.__name__).getChild("{host}:{port}".format(host=host, port=port)) - self.assertIn(mock.call.log(level=logging.DEBUG, msg=command_log), log.mock_calls) - - def test_005_execute_async_with_sudo_enforce(self, client, policy, logger): - chan = mock.Mock() - open_session = mock.Mock(return_value=chan) - transport = mock.Mock() - transport.attach_mock(open_session, "open_session") - get_transport = mock.Mock(return_value=transport) - _ssh = mock.Mock() - _ssh.attach_mock(get_transport, "get_transport") - client.return_value = _ssh - - ssh = self.get_ssh() - self.assertFalse(ssh.sudo_mode) - with exec_helpers.SSHClient.sudo(ssh, enforce=True): - self.assertTrue(ssh.sudo_mode) - # noinspection PyTypeChecker - result = ssh.execute_async(command=command) - self.assertFalse(ssh.sudo_mode) - - get_transport.assert_called_once() - open_session.assert_called_once() - - self.assertIn(chan, result) - chan.assert_has_calls( - ( - mock.call.makefile("wb"), - mock.call.makefile("rb"), - mock.call.makefile_stderr("rb"), - mock.call.exec_command("sudo -S bash -c '" 'eval "$(base64 -d <(echo "{0}"))"\''.format(encoded_cmd)), - ) - ) - log = logger(ssh.__class__.__name__).getChild("{host}:{port}".format(host=host, port=port)) - self.assertIn(mock.call.log(level=logging.DEBUG, msg=command_log), log.mock_calls) - - def test_006_execute_async_with_no_sudo_enforce(self, client, policy, logger): - chan = mock.Mock() - open_session = mock.Mock(return_value=chan) - transport = mock.Mock() - transport.attach_mock(open_session, "open_session") - get_transport = mock.Mock(return_value=transport) - _ssh = mock.Mock() - _ssh.attach_mock(get_transport, "get_transport") - client.return_value = _ssh - - ssh = self.get_ssh() - ssh.sudo_mode = True - - with ssh.sudo(enforce=False): - # noinspection PyTypeChecker - result = ssh.execute_async(command=command) - get_transport.assert_called_once() - open_session.assert_called_once() - - self.assertIn(chan, result) - chan.assert_has_calls( - ( - mock.call.makefile("wb"), - mock.call.makefile("rb"), - mock.call.makefile_stderr("rb"), - mock.call.exec_command("{}\n".format(command)), - ) - ) - log = logger(ssh.__class__.__name__).getChild("{host}:{port}".format(host=host, port=port)) - self.assertIn(mock.call.log(level=logging.DEBUG, msg=command_log), log.mock_calls) - - def test_007_execute_async_with_sudo_none_enforce(self, client, policy, logger): - chan = mock.Mock() - open_session = mock.Mock(return_value=chan) - transport = mock.Mock() - transport.attach_mock(open_session, "open_session") - get_transport = mock.Mock(return_value=transport) - _ssh = mock.Mock() - _ssh.attach_mock(get_transport, "get_transport") - client.return_value = _ssh - - ssh = self.get_ssh() - ssh.sudo_mode = False - - with ssh.sudo(): - # noinspection PyTypeChecker - result = ssh.execute_async(command=command) - get_transport.assert_called_once() - open_session.assert_called_once() - - self.assertIn(chan, result) - chan.assert_has_calls( - ( - mock.call.makefile("wb"), - mock.call.makefile("rb"), - mock.call.makefile_stderr("rb"), - mock.call.exec_command("{}\n".format(command)), - ) - ) - log = logger(ssh.__class__.__name__).getChild("{host}:{port}".format(host=host, port=port)) - self.assertIn(mock.call.log(level=logging.DEBUG, msg=command_log), log.mock_calls) - - @mock.patch("exec_helpers.ssh_auth.SSHAuth.enter_password") - def test_008_execute_async_sudo_password(self, enter_password, client, policy, logger): - stdin = mock.Mock(name="stdin") - stdout = mock.Mock(name="stdout") - stdout_channel = mock.Mock() - stdout_channel.configure_mock(closed=False) - stdout.attach_mock(stdout_channel, "channel") - makefile = mock.Mock(side_effect=[stdin, stdout]) - chan = mock.Mock() - chan.attach_mock(makefile, "makefile") - open_session = mock.Mock(return_value=chan) - transport = mock.Mock() - transport.attach_mock(open_session, "open_session") - get_transport = mock.Mock(return_value=transport) - _ssh = mock.Mock() - _ssh.attach_mock(get_transport, "get_transport") - client.return_value = _ssh - - ssh = self.get_ssh() - ssh.sudo_mode = True - - # noinspection PyTypeChecker - result = ssh.execute_async(command=command) - get_transport.assert_called_once() - open_session.assert_called_once() - # raise ValueError(closed.mock_calls) - enter_password.assert_called_once_with(stdin) - stdin.assert_has_calls((mock.call.flush(),)) - - self.assertIn(chan, result) - chan.assert_has_calls( - ( - mock.call.makefile("wb"), - mock.call.makefile("rb"), - mock.call.makefile_stderr("rb"), - mock.call.exec_command("sudo -S bash -c '" 'eval "$(base64 -d <(echo "{0}"))"\''.format(encoded_cmd)), - ) - ) - log = logger(ssh.__class__.__name__).getChild("{host}:{port}".format(host=host, port=port)) - self.assertIn(mock.call.log(level=logging.DEBUG, msg=command_log), log.mock_calls) - - def test_009_execute_async_verbose(self, client, policy, logger): - chan = mock.Mock() - open_session = mock.Mock(return_value=chan) - transport = mock.Mock() - transport.attach_mock(open_session, "open_session") - get_transport = mock.Mock(return_value=transport) - _ssh = mock.Mock() - _ssh.attach_mock(get_transport, "get_transport") - client.return_value = _ssh - - ssh = self.get_ssh() - - # noinspection PyTypeChecker - result = ssh.execute_async(command=command, verbose=True) - get_transport.assert_called_once() - open_session.assert_called_once() - - self.assertIn(chan, result) - chan.assert_has_calls( - ( - mock.call.makefile("wb"), - mock.call.makefile("rb"), - mock.call.makefile_stderr("rb"), - mock.call.exec_command("{}\n".format(command)), - ) - ) - log = logger(ssh.__class__.__name__).getChild("{host}:{port}".format(host=host, port=port)) - self.assertIn(mock.call.log(level=logging.INFO, msg=command_log), log.mock_calls) - - def test_010_execute_async_mask_command(self, client, policy, logger): - cmd = "USE='secret=secret_pass' do task" - log_mask_re = r"secret\s*=\s*([A-Z-a-z0-9_\-]+)" - masked_cmd = "USE='secret=<*masked*>' do task" - cmd_log = "Executing command:\n{!r}\n".format(masked_cmd) - - chan = mock.Mock() - open_session = mock.Mock(return_value=chan) - transport = mock.Mock() - transport.attach_mock(open_session, "open_session") - get_transport = mock.Mock(return_value=transport) - _ssh = mock.Mock() - _ssh.attach_mock(get_transport, "get_transport") - client.return_value = _ssh - - ssh = self.get_ssh() - - # noinspection PyTypeChecker - result = ssh.execute_async(command=cmd, log_mask_re=log_mask_re) - get_transport.assert_called_once() - open_session.assert_called_once() - - self.assertIn(chan, result) - chan.assert_has_calls( - ( - mock.call.makefile("wb"), - mock.call.makefile("rb"), - mock.call.makefile_stderr("rb"), - mock.call.exec_command("{}\n".format(cmd)), - ) - ) - log = logger(ssh.__class__.__name__).getChild("{host}:{port}".format(host=host, port=port)) - self.assertIn(mock.call.log(level=logging.DEBUG, msg=cmd_log), log.mock_calls) - - def test_011_check_stdin_str(self, client, policy, logger): - stdin_val = "this is a line" - - stdin = mock.Mock(name="stdin") - stdin_channel = mock.Mock() - stdin_channel.configure_mock(closed=False) - stdin.attach_mock(stdin_channel, "channel") - - stdout = mock.Mock(name="stdout") - stdout_channel = mock.Mock() - stdout_channel.configure_mock(closed=False) - stdout.attach_mock(stdout_channel, "channel") - - chan = mock.Mock() - chan.attach_mock(mock.Mock(side_effect=[stdin, stdout]), "makefile") - - open_session = mock.Mock(return_value=chan) - - transport = mock.Mock() - transport.attach_mock(open_session, "open_session") - get_transport = mock.Mock(return_value=transport) - _ssh = mock.Mock() - _ssh.attach_mock(get_transport, "get_transport") - client.return_value = _ssh - - ssh = self.get_ssh() - - # noinspection PyTypeChecker - result = ssh.execute_async(command=print_stdin, stdin=stdin_val) - - get_transport.assert_called_once() - open_session.assert_called_once() - stdin.assert_has_calls([mock.call.write("{val}\n".format(val=stdin_val)), mock.call.flush()]) - - self.assertIn(chan, result) - chan.assert_has_calls( - ( - mock.call.makefile("wb"), - mock.call.makefile("rb"), - mock.call.makefile_stderr("rb"), - mock.call.exec_command("{val}\n".format(val=print_stdin)), - ) - ) - - def test_012_check_stdin_bytes(self, client, policy, logger): - stdin_val = b"this is a line" - - stdin = mock.Mock(name="stdin") - stdin_channel = mock.Mock() - stdin_channel.configure_mock(closed=False) - stdin.attach_mock(stdin_channel, "channel") - - stdout = mock.Mock(name="stdout") - stdout_channel = mock.Mock() - stdout_channel.configure_mock(closed=False) - stdout.attach_mock(stdout_channel, "channel") - - chan = mock.Mock() - chan.attach_mock(mock.Mock(side_effect=[stdin, stdout]), "makefile") - - open_session = mock.Mock(return_value=chan) - - transport = mock.Mock() - transport.attach_mock(open_session, "open_session") - get_transport = mock.Mock(return_value=transport) - _ssh = mock.Mock() - _ssh.attach_mock(get_transport, "get_transport") - client.return_value = _ssh - - ssh = self.get_ssh() - - # noinspection PyTypeChecker - result = ssh.execute_async(command=print_stdin, stdin=stdin_val) - - get_transport.assert_called_once() - open_session.assert_called_once() - stdin.assert_has_calls([mock.call.write("{val}\n".format(val=stdin_val)), mock.call.flush()]) - - self.assertIn(chan, result) - chan.assert_has_calls( - ( - mock.call.makefile("wb"), - mock.call.makefile("rb"), - mock.call.makefile_stderr("rb"), - mock.call.exec_command("{val}\n".format(val=print_stdin)), - ) - ) - - def test_013_check_stdin_bytearray(self, client, policy, logger): - stdin_val = bytearray(b"this is a line") - - stdin = mock.Mock(name="stdin") - stdin_channel = mock.Mock() - stdin_channel.configure_mock(closed=False) - stdin.attach_mock(stdin_channel, "channel") - - stdout = mock.Mock(name="stdout") - stdout_channel = mock.Mock() - stdout_channel.configure_mock(closed=False) - stdout.attach_mock(stdout_channel, "channel") - - chan = mock.Mock() - chan.attach_mock(mock.Mock(side_effect=[stdin, stdout]), "makefile") - - open_session = mock.Mock(return_value=chan) - - transport = mock.Mock() - transport.attach_mock(open_session, "open_session") - get_transport = mock.Mock(return_value=transport) - _ssh = mock.Mock() - _ssh.attach_mock(get_transport, "get_transport") - client.return_value = _ssh - - ssh = self.get_ssh() - - # noinspection PyTypeChecker - result = ssh.execute_async(command=print_stdin, stdin=stdin_val) - - get_transport.assert_called_once() - open_session.assert_called_once() - stdin.assert_has_calls([mock.call.write("{val}\n".format(val=stdin_val)), mock.call.flush()]) - - self.assertIn(chan, result) - chan.assert_has_calls( - ( - mock.call.makefile("wb"), - mock.call.makefile("rb"), - mock.call.makefile_stderr("rb"), - mock.call.exec_command("{val}\n".format(val=print_stdin)), - ) - ) - - def test_014_check_stdin_closed(self, client, policy, logger): - stdin_val = "this is a line" - - stdin = mock.Mock(name="stdin") - stdin_channel = mock.Mock() - stdin_channel.configure_mock(closed=True) - stdin.attach_mock(stdin_channel, "channel") - - stdout = mock.Mock(name="stdout") - stdout_channel = mock.Mock() - stdout_channel.configure_mock(closed=False) - stdout.attach_mock(stdout_channel, "channel") - - chan = mock.Mock() - chan.attach_mock(mock.Mock(side_effect=[stdin, stdout]), "makefile") - - open_session = mock.Mock(return_value=chan) - - transport = mock.Mock() - transport.attach_mock(open_session, "open_session") - get_transport = mock.Mock(return_value=transport) - _ssh = mock.Mock() - _ssh.attach_mock(get_transport, "get_transport") - client.return_value = _ssh - - ssh = self.get_ssh() - - # noinspection PyTypeChecker - result = ssh.execute_async(command=print_stdin, stdin=stdin_val) - - get_transport.assert_called_once() - open_session.assert_called_once() - stdin.assert_not_called() - - log = logger(ssh.__class__.__name__).getChild("{host}:{port}".format(host=host, port=port)) - log.warning.assert_called_once_with("STDIN Send failed: closed channel") - - self.assertIn(chan, result) - chan.assert_has_calls( - ( - mock.call.makefile("wb"), - mock.call.makefile("rb"), - mock.call.makefile_stderr("rb"), - mock.call.exec_command("{val}\n".format(val=print_stdin)), - ) - ) - - def test_015_keepalive(self, client, policy, logger): - chan = mock.Mock() - open_session = mock.Mock(return_value=chan) - transport = mock.Mock() - transport.attach_mock(open_session, "open_session") - get_transport = mock.Mock(return_value=transport) - _ssh = mock.Mock() - _ssh.attach_mock(get_transport, "get_transport") - client.return_value = _ssh - - ssh = self.get_ssh() - - with ssh: - pass - - _ssh.close.assert_not_called() - - def test_016_no_keepalive(self, client, policy, logger): - chan = mock.Mock() - open_session = mock.Mock(return_value=chan) - transport = mock.Mock() - transport.attach_mock(open_session, "open_session") - get_transport = mock.Mock(return_value=transport) - _ssh = mock.Mock() - _ssh.attach_mock(get_transport, "get_transport") - client.return_value = _ssh - - ssh = self.get_ssh() - ssh.keepalive_mode = False - - with ssh: - pass - - _ssh.close.assert_called_once() - - def test_017_keepalive_enforced(self, client, policy, logger): - chan = mock.Mock() - open_session = mock.Mock(return_value=chan) - transport = mock.Mock() - transport.attach_mock(open_session, "open_session") - get_transport = mock.Mock(return_value=transport) - _ssh = mock.Mock() - _ssh.attach_mock(get_transport, "get_transport") - client.return_value = _ssh - - ssh = self.get_ssh() - ssh.keepalive_mode = False - - with ssh.keepalive(): - pass - - _ssh.close.assert_not_called() - - def test_018_no_keepalive_enforced(self, client, policy, logger): - chan = mock.Mock() - open_session = mock.Mock(return_value=chan) - transport = mock.Mock() - transport.attach_mock(open_session, "open_session") - get_transport = mock.Mock(return_value=transport) - _ssh = mock.Mock() - _ssh.attach_mock(get_transport, "get_transport") - client.return_value = _ssh - - ssh = self.get_ssh() - - with ssh.keepalive(enforce=False): - pass - - _ssh.close.assert_called_once() - - @staticmethod - def get_patched_execute_async_retval(ec=0, stderr_val=None, open_stdout=True, open_stderr=True, cmd_log=None): - """get patched execute_async retval - - :rtype: - Tuple( - mock.Mock, - str, - exec_result.ExecResult, - FakeStream, - FakeStream) - """ - if open_stdout: - out = stdout_list - stdout = FakeStream(*out) - else: - stdout = out = None - if open_stderr: - err = stderr_list if stderr_val is None else [] - stderr = FakeStream(*err) - else: - stderr = err = None - - exit_code = ec - chan = mock.Mock() - chan.attach_mock(mock.Mock(return_value=exit_code), "recv_exit_status") - - status_event = mock.Mock() - status_event.attach_mock(mock.Mock(), "wait") - chan.attach_mock(status_event, "status_event") - chan.configure_mock(exit_status=exit_code) - - # noinspection PyTypeChecker - exp_result = exec_result.ExecResult( - cmd=cmd_log if cmd_log is not None else command, stderr=err, stdout=out, exit_code=ec - ) - - return chan, "", exp_result, stderr, stdout - - @mock.patch("exec_helpers.ssh_client.SSHClient.execute_async") - def test_019_execute(self, execute_async, client, policy, logger): - (chan, _stdin, exp_result, stderr, stdout) = self.get_patched_execute_async_retval() - is_set = mock.Mock(return_value=True) - chan.status_event.attach_mock(is_set, "is_set") - - execute_async.return_value = exec_helpers.SshExecuteAsyncResult(chan, _stdin, stderr, stdout) - - ssh = self.get_ssh() - - logger.reset_mock() - - # noinspection PyTypeChecker - result = ssh.execute(command=command, verbose=False) - - self.assertEqual(result, exp_result) - execute_async.assert_called_once_with(command, verbose=False) - chan.assert_has_calls((mock.call.status_event.is_set(),)) - message = self.gen_cmd_result_log_message(result) - log = logger(ssh.__class__.__name__).getChild("{host}:{port}".format(host=host, port=port)).log - log.assert_has_calls( - [mock.call(level=logging.DEBUG, msg=str(x.rstrip().decode("utf-8"))) for x in stdout_list] - + [mock.call(level=logging.DEBUG, msg=str(x.rstrip().decode("utf-8"))) for x in stderr_list] - + [mock.call(level=logging.DEBUG, msg=message)] - ) - - @mock.patch("exec_helpers.ssh_client.SSHClient.execute_async") - def test_020_execute_verbose(self, execute_async, client, policy, logger): - (chan, _stdin, exp_result, stderr, stdout) = self.get_patched_execute_async_retval() - is_set = mock.Mock(return_value=True) - chan.status_event.attach_mock(is_set, "is_set") - - execute_async.return_value = exec_helpers.SshExecuteAsyncResult(chan, _stdin, stderr, stdout) - - ssh = self.get_ssh() - - logger.reset_mock() - - # noinspection PyTypeChecker - result = ssh.execute(command=command, verbose=True) - - self.assertEqual(result, exp_result) - execute_async.assert_called_once_with(command, verbose=True) - chan.assert_has_calls((mock.call.status_event.is_set(),)) - - message = self.gen_cmd_result_log_message(result) - log = logger(ssh.__class__.__name__).getChild("{host}:{port}".format(host=host, port=port)).log - log.assert_has_calls( - [mock.call(level=logging.INFO, msg=str(x.rstrip().decode("utf-8"))) for x in stdout_list] - + [mock.call(level=logging.INFO, msg=str(x.rstrip().decode("utf-8"))) for x in stderr_list] - + [mock.call(level=logging.INFO, msg=message)] - ) - - @mock.patch("exec_helpers.ssh_client.SSHClient.execute_async") - def test_021_execute_no_stdout(self, execute_async, client, policy, logger): - (chan, _stdin, exp_result, stderr, stdout) = self.get_patched_execute_async_retval(open_stdout=False) - chan.status_event.attach_mock(mock.Mock(return_value=True), "is_set") - - execute_async.return_value = exec_helpers.SshExecuteAsyncResult(chan, _stdin, stderr, stdout) - - ssh = self.get_ssh() - - logger.reset_mock() - - # noinspection PyTypeChecker - result = ssh.execute(command=command, verbose=False, open_stdout=False) - - self.assertEqual(result, exp_result) - execute_async.assert_called_once_with(command, verbose=False, open_stdout=False) - message = self.gen_cmd_result_log_message(result) - log = logger(ssh.__class__.__name__).getChild("{host}:{port}".format(host=host, port=port)).log - log.assert_has_calls( - [mock.call(level=logging.DEBUG, msg=str(x.rstrip().decode("utf-8"))) for x in stderr_list] - + [mock.call(level=logging.DEBUG, msg=message)] - ) - - @mock.patch("exec_helpers.ssh_client.SSHClient.execute_async") - def test_022_execute_no_stderr(self, execute_async, client, policy, logger): - (chan, _stdin, exp_result, stderr, stdout) = self.get_patched_execute_async_retval(open_stderr=False) - chan.status_event.attach_mock(mock.Mock(return_value=True), "is_set") - - execute_async.return_value = exec_helpers.SshExecuteAsyncResult(chan, _stdin, stderr, stdout) - - ssh = self.get_ssh() - - logger.reset_mock() - - # noinspection PyTypeChecker - result = ssh.execute(command=command, verbose=False, open_stderr=False) - - self.assertEqual(result, exp_result) - execute_async.assert_called_once_with(command, verbose=False, open_stderr=False) - message = self.gen_cmd_result_log_message(result) - log = logger(ssh.__class__.__name__).getChild("{host}:{port}".format(host=host, port=port)).log - log.assert_has_calls( - [mock.call(level=logging.DEBUG, msg=str(x.rstrip().decode("utf-8"))) for x in stdout_list] - + [mock.call(level=logging.DEBUG, msg=message)] - ) - - @mock.patch("exec_helpers.ssh_client.SSHClient.execute_async") - def test_023_execute_no_stdout_stderr(self, execute_async, client, policy, logger): - (chan, _stdin, exp_result, stderr, stdout) = self.get_patched_execute_async_retval( - open_stdout=False, open_stderr=False - ) - chan.status_event.attach_mock(mock.Mock(return_value=True), "is_set") - - execute_async.return_value = exec_helpers.SshExecuteAsyncResult(chan, _stdin, stderr, stdout) - - ssh = self.get_ssh() - - logger.reset_mock() - - # noinspection PyTypeChecker - result = ssh.execute(command=command, verbose=False, open_stdout=False, open_stderr=False) - - self.assertEqual(result, exp_result) - execute_async.assert_called_once_with(command, verbose=False, open_stdout=False, open_stderr=False) - message = self.gen_cmd_result_log_message(result) - log = logger(ssh.__class__.__name__).getChild("{host}:{port}".format(host=host, port=port)).log - log.assert_has_calls([mock.call(level=logging.DEBUG, msg=message)]) - - @mock.patch("time.sleep", autospec=True) - @mock.patch("exec_helpers.ssh_client.SSHClient.execute_async") - def test_024_execute_timeout(self, execute_async, sleep, client, policy, logger): - (chan, _stdin, exp_result, stderr, stdout) = self.get_patched_execute_async_retval() - is_set = mock.Mock(return_value=True) - chan.status_event.attach_mock(is_set, "is_set") - - execute_async.return_value = exec_helpers.SshExecuteAsyncResult(chan, _stdin, stderr, stdout) - - ssh = self.get_ssh() - - logger.reset_mock() - - # noinspection PyTypeChecker - result = ssh.execute(command=command, verbose=False, timeout=0.2) - - self.assertEqual(result, exp_result) - execute_async.assert_called_once_with(command, verbose=False) - chan.assert_has_calls((mock.call.status_event.is_set(),)) - message = self.gen_cmd_result_log_message(result) - log = logger(ssh.__class__.__name__).getChild("{host}:{port}".format(host=host, port=port)).log - self.assertIn(mock.call(level=logging.DEBUG, msg=message), log.mock_calls) - - @mock.patch("time.sleep", autospec=True) - @mock.patch("exec_helpers.ssh_client.SSHClient.execute_async") - def test_025_execute_timeout_fail(self, execute_async, sleep, client, policy, logger): - (chan, _stdin, _, stderr, stdout) = self.get_patched_execute_async_retval() - is_set = mock.Mock(return_value=False) - chan.status_event.attach_mock(is_set, "is_set") - chan.status_event.attach_mock(mock.Mock(), "wait") - - execute_async.return_value = exec_helpers.SshExecuteAsyncResult(chan, _stdin, stderr, stdout) - - ssh = self.get_ssh() - - logger.reset_mock() - - with self.assertRaises(exec_helpers.ExecHelperTimeoutError) as cm: - # noinspection PyTypeChecker - ssh.execute(command=command, verbose=False, timeout=0.2) - - self.assertEqual(cm.exception.timeout, 0.2) - self.assertEqual(cm.exception.cmd, command) - self.assertEqual(cm.exception.stdout, stdout_str) - self.assertEqual(cm.exception.stderr, stderr_str) - - execute_async.assert_called_once_with(command, verbose=False) - chan.assert_has_calls((mock.call.status_event.is_set(),)) - - @mock.patch("exec_helpers.ssh_client.SSHClient.execute_async") - def test_026_execute_mask_command(self, execute_async, client, policy, logger): - cmd = "USE='secret=secret_pass' do task" - log_mask_re = r"secret\s*=\s*([A-Z-a-z0-9_\-]+)" - masked_cmd = "USE='secret=<*masked*>' do task" - - (chan, _stdin, exp_result, stderr, stdout) = self.get_patched_execute_async_retval(cmd_log=masked_cmd) - is_set = mock.Mock(return_value=True) - chan.status_event.attach_mock(is_set, "is_set") - - execute_async.return_value = exec_helpers.SshExecuteAsyncResult(chan, _stdin, stderr, stdout) - - ssh = self.get_ssh() - - logger.reset_mock() - - # noinspection PyTypeChecker - result = ssh.execute(command=cmd, verbose=False, log_mask_re=log_mask_re) - - self.assertEqual(result, exp_result) - execute_async.assert_called_once_with(cmd, log_mask_re=log_mask_re, verbose=False) - chan.assert_has_calls((mock.call.status_event.is_set(),)) - message = self.gen_cmd_result_log_message(result) - log = logger(ssh.__class__.__name__).getChild("{host}:{port}".format(host=host, port=port)).log - log.assert_has_calls( - [mock.call(level=logging.DEBUG, msg=str(x.rstrip().decode("utf-8"))) for x in stdout_list] - + [mock.call(level=logging.DEBUG, msg=str(x.rstrip().decode("utf-8"))) for x in stderr_list] - + [mock.call(level=logging.DEBUG, msg=message)] - ) - - @mock.patch("exec_helpers.ssh_client.SSHClient.execute_async") - def test_027_execute_together(self, execute_async, client, policy, logger): - (chan, _stdin, _, stderr, stdout) = self.get_patched_execute_async_retval() - execute_async.return_value = exec_helpers.SshExecuteAsyncResult(chan, _stdin, stderr, stdout) - - host2 = "127.0.0.2" - - ssh = self.get_ssh() - # noinspection PyTypeChecker - ssh2 = exec_helpers.SSHClient( - host=host2, port=port, auth=exec_helpers.SSHAuth(username=username, password=password) - ) - - remotes = [ssh, ssh2] - - # noinspection PyTypeChecker - results = exec_helpers.SSHClient.execute_together(remotes=remotes, command=command) - - self.assertEqual(execute_async.call_count, len(remotes)) - self.assertEqual( - sorted(chan.mock_calls), - sorted( - ( - mock.call.status_event.wait(constants.DEFAULT_TIMEOUT), - mock.call.recv_exit_status(), - mock.call.close(), - mock.call.status_event.wait(constants.DEFAULT_TIMEOUT), - mock.call.recv_exit_status(), - mock.call.close(), - ) - ), - ) - self.assertIn((ssh.hostname, ssh.port), results) - self.assertIn((ssh2.hostname, ssh2.port), results) - for result in results.values(): # type: exec_result.ExecResult - self.assertEqual(result.cmd, command) - - # noinspection PyTypeChecker - exec_helpers.SSHClient.execute_together(remotes=remotes, command=command, expected=[1], raise_on_err=False) - - with self.assertRaises(exec_helpers.ParallelCallProcessError): - # noinspection PyTypeChecker - exec_helpers.SSHClient.execute_together(remotes=remotes, command=command, expected=[1]) - - @mock.patch("exec_helpers.ssh_client.SSHClient.execute_async") - def test_028_execute_together_exceptions( - self, execute_async: mock.Mock, client: mock.Mock, policy: mock.Mock, logger: mock.Mock - ) -> None: - """Simple scenario: execute_async fail on all nodes.""" - execute_async.side_effect = RuntimeError - - host2 = "127.0.0.2" - - ssh = self.get_ssh() - # noinspection PyTypeChecker - ssh2 = exec_helpers.SSHClient( - host=host2, port=port, auth=exec_helpers.SSHAuth(username=username, password=password) - ) - - remotes = [ssh, ssh2] - - # noinspection PyTypeChecker - with self.assertRaises(exec_helpers.ParallelCallExceptions) as cm: - exec_helpers.SSHClient.execute_together(remotes=remotes, command=command) - - exc = cm.exception # type: exec_helpers.ParallelCallExceptions - self.assertEqual(list(sorted(exc.exceptions)), [(host, port), (host2, port)]) - for exception in exc.exceptions.values(): - self.assertIsInstance(exception, RuntimeError) - - @mock.patch("exec_helpers.ssh_client.SSHClient.execute") - def test_029_check_call(self, execute, client, policy, logger): - exit_code = 0 - return_value = exec_result.ExecResult(cmd=command, stdout=stdout_list, stderr=stderr_list, exit_code=exit_code) - execute.return_value = return_value - - verbose = False - - ssh = self.get_ssh() - - # noinspection PyTypeChecker - result = ssh.check_call(command=command, verbose=verbose, timeout=None) - execute.assert_called_once_with(command, verbose, None) - self.assertEqual(result, return_value) - - exit_code = 1 - execute.reset_mock() - return_value = exec_result.ExecResult(cmd=command, stdout=stdout_list, stderr=stderr_list, exit_code=exit_code) - execute.return_value = return_value - with self.assertRaises(exec_helpers.CalledProcessError) as cm: - # noinspection PyTypeChecker - ssh.check_call(command=command, verbose=verbose, timeout=None) - exc = cm.exception - self.assertEqual(exc.cmd, command) - self.assertEqual(exc.returncode, 1) - self.assertEqual(exc.stdout, stdout_str) - self.assertEqual(exc.stderr, stderr_str) - execute.assert_called_once_with(command, verbose, None) - - @mock.patch("exec_helpers.ssh_client.SSHClient.execute") - def test_030_check_call_expected(self, execute, client, policy, logger): - exit_code = 0 - return_value = exec_result.ExecResult(cmd=command, stdout=stdout_list, stderr=stderr_list, exit_code=exit_code) - execute.return_value = return_value - - verbose = False - - ssh = self.get_ssh() - - # noinspection PyTypeChecker - result = ssh.check_call(command=command, verbose=verbose, timeout=None, expected=[0, 75]) - execute.assert_called_once_with(command, verbose, None) - self.assertEqual(result, return_value) - - exit_code = 1 - return_value = exec_result.ExecResult(cmd=command, stdout=stdout_list, stderr=stderr_list, exit_code=exit_code) - execute.reset_mock() - execute.return_value = return_value - with self.assertRaises(exec_helpers.CalledProcessError): - # noinspection PyTypeChecker - ssh.check_call(command=command, verbose=verbose, timeout=None, expected=[0, 75]) - execute.assert_called_once_with(command, verbose, None) - - @mock.patch("exec_helpers.ssh_client.SSHClient.check_call") - def test_031_check_stderr(self, check_call, client, policy, logger): - return_value = exec_result.ExecResult(cmd=command, stdout=stdout_list, stderr=[], exit_code=0) - check_call.return_value = return_value - - verbose = False - raise_on_err = True - - ssh = self.get_ssh() - - # noinspection PyTypeChecker - result = ssh.check_stderr(command=command, verbose=verbose, timeout=None, raise_on_err=raise_on_err) - check_call.assert_called_once_with(command, verbose, timeout=None, error_info=None, raise_on_err=raise_on_err) - self.assertEqual(result, return_value) - - return_value = exec_result.ExecResult(cmd=command, stdout=stdout_list, stderr=stderr_list, exit_code=0) - - check_call.reset_mock() - check_call.return_value = return_value - with self.assertRaises(exec_helpers.CalledProcessError): - # noinspection PyTypeChecker - ssh.check_stderr(command=command, verbose=verbose, timeout=None, raise_on_err=raise_on_err) - check_call.assert_called_once_with(command, verbose, timeout=None, error_info=None, raise_on_err=raise_on_err) - - -@mock.patch("logging.getLogger", autospec=True) -@mock.patch("paramiko.AutoAddPolicy", autospec=True, return_value="AutoAddPolicy") -@mock.patch("paramiko.SSHClient", autospec=True) -@mock.patch("paramiko.Transport", autospec=True) -class TestExecuteThrowHost(unittest.TestCase): - def tearDown(self): - with mock.patch("warnings.warn"): - exec_helpers.SSHClient._clear_cache() - - @staticmethod - def prepare_execute_through_host(transp, client, exit_code): - intermediate_channel = mock.Mock(name="intermediate_channel") - - open_channel = mock.Mock(return_value=intermediate_channel, name="open_channel") - intermediate_transport = mock.Mock(name="intermediate_transport", spec="paramiko.transport.Transport") - intermediate_transport.attach_mock(open_channel, "open_channel") - get_transport = mock.Mock(return_value=intermediate_transport, name="get_transport") - - _ssh = mock.Mock(neme="_ssh") - _ssh.attach_mock(get_transport, "get_transport") - client.return_value = _ssh - - transport = mock.Mock(name="transport") - transp.return_value = transport - - recv_exit_status = mock.Mock(return_value=exit_code) - - channel = mock.Mock() - channel.attach_mock(mock.Mock(return_value=FakeStream(b" \n", b"2\n", b"3\n", b" \n")), "makefile") - channel.attach_mock(mock.Mock(return_value=FakeStream(b" \n", b"0\n", b"1\n", b" \n")), "makefile_stderr") - - channel.attach_mock(recv_exit_status, "recv_exit_status") - open_session = mock.Mock(return_value=channel, name="open_session") - transport.attach_mock(open_session, "open_session") - - wait = mock.Mock() - status_event = mock.Mock() - status_event.attach_mock(wait, "wait") - channel.attach_mock(status_event, "status_event") - channel.configure_mock(exit_status=exit_code) - - is_set = mock.Mock(return_value=True) - channel.status_event.attach_mock(is_set, "is_set") - - return (open_session, transport, channel, get_transport, open_channel, intermediate_channel) - - def test_01_execute_through_host_no_creds(self, transp, client, policy, logger): - target = "127.0.0.2" - exit_code = 0 - - # noinspection PyTypeChecker - return_value = exec_result.ExecResult(cmd=command, stderr=stderr_list, stdout=stdout_list, exit_code=exit_code) - - ( - open_session, - transport, - channel, - get_transport, - open_channel, - intermediate_channel, - ) = self.prepare_execute_through_host(transp=transp, client=client, exit_code=exit_code) - - # noinspection PyTypeChecker - ssh = exec_helpers.SSHClient( - host=host, port=port, auth=exec_helpers.SSHAuth(username=username, password=password) - ) - - # noinspection PyTypeChecker - result = ssh.execute_through_host(target, command) - self.assertEqual(result, return_value) - get_transport.assert_called_once() - open_channel.assert_called_once() - transp.assert_called_once_with(intermediate_channel) - open_session.assert_called_once() - transport.assert_has_calls( - (mock.call.connect(username=username, password=password, pkey=None), mock.call.open_session()) - ) - channel.assert_has_calls( - ( - mock.call.makefile("rb"), - mock.call.makefile_stderr("rb"), - mock.call.exec_command(command), - mock.call.recv_ready(), - mock.call.recv_stderr_ready(), - mock.call.status_event.is_set(), - mock.call.close(), - ) - ) - - def test_02_execute_through_host_auth(self, transp, client, policy, logger): - _login = "cirros" - _password = "cubswin:)" - - target = "127.0.0.2" - exit_code = 0 - - # noinspection PyTypeChecker - return_value = exec_result.ExecResult(cmd=command, stderr=stderr_list, stdout=stdout_list, exit_code=exit_code) - - ( - open_session, - transport, - channel, - get_transport, - open_channel, - intermediate_channel, - ) = self.prepare_execute_through_host(transp, client, exit_code=exit_code) - - # noinspection PyTypeChecker - ssh = exec_helpers.SSHClient( - host=host, - port=port, - auth=exec_helpers.SSHAuth( - username=username, password=password, key_filename="~/fake_key", passphrase="fake_passphrase" - ), - ) - - # noinspection PyTypeChecker - result = ssh.execute_through_host( - target, command, auth=exec_helpers.SSHAuth(username=_login, password=_password) - ) - self.assertEqual(result, return_value) - get_transport.assert_called_once() - open_channel.assert_called_once() - transp.assert_called_once_with(intermediate_channel) - open_session.assert_called_once() - transport.assert_has_calls( - (mock.call.connect(username=_login, password=_password, pkey=None), mock.call.open_session()) - ) - channel.assert_has_calls( - ( - mock.call.makefile("rb"), - mock.call.makefile_stderr("rb"), - mock.call.exec_command(command), - mock.call.recv_ready(), - mock.call.recv_stderr_ready(), - mock.call.status_event.is_set(), - mock.call.close(), - ) - ) - - def test_03_execute_through_host_get_pty(self, transp, client, policy, logger): - target = "127.0.0.2" - exit_code = 0 - - # noinspection PyTypeChecker - return_value = exec_result.ExecResult(cmd=command, stderr=stderr_list, stdout=stdout_list, exit_code=exit_code) - - ( - open_session, - transport, - channel, - get_transport, - open_channel, - intermediate_channel, - ) = self.prepare_execute_through_host(transp=transp, client=client, exit_code=exit_code) - - # noinspection PyTypeChecker - ssh = exec_helpers.SSHClient( - host=host, port=port, auth=exec_helpers.SSHAuth(username=username, password=password) - ) - - # noinspection PyTypeChecker - result = ssh.execute_through_host(target, command, get_pty=True) - self.assertEqual(result, return_value) - get_transport.assert_called_once() - open_channel.assert_called_once() - transp.assert_called_once_with(intermediate_channel) - open_session.assert_called_once() - transport.assert_has_calls( - (mock.call.connect(username=username, password=password, pkey=None), mock.call.open_session()) - ) - - channel.assert_has_calls( - ( - mock.call.get_pty(term="vt100", width=80, height=24, width_pixels=0, height_pixels=0), - mock.call.makefile("rb"), - mock.call.makefile_stderr("rb"), - mock.call.exec_command(command), - mock.call.recv_ready(), - mock.call.recv_stderr_ready(), - mock.call.status_event.is_set(), - mock.call.close(), - ) - ) - - -@mock.patch("logging.getLogger", autospec=True) -@mock.patch("paramiko.AutoAddPolicy", autospec=True, return_value="AutoAddPolicy") -@mock.patch("paramiko.SSHClient", autospec=True) -class TestSftp(unittest.TestCase): - def tearDown(self): - with mock.patch("warnings.warn"): - exec_helpers.SSHClient._clear_cache() - - @staticmethod - def prepare_sftp_file_tests(client): - _ssh = mock.Mock() - client.return_value = _ssh - _sftp = mock.Mock() - open_sftp = mock.Mock(parent=_ssh, return_value=_sftp) - _ssh.attach_mock(open_sftp, "open_sftp") - - # noinspection PyTypeChecker - ssh = exec_helpers.SSHClient( - host=host, port=port, auth=exec_helpers.SSHAuth(username=username, password=password) - ) - return ssh, _sftp - - def test_exists(self, client, *args): - ssh, _sftp = self.prepare_sftp_file_tests(client) - lstat = mock.Mock() - _sftp.attach_mock(lstat, "lstat") - dst = "/etc" - - # noinspection PyTypeChecker - result = ssh.exists(dst) - self.assertTrue(result) - lstat.assert_called_once_with(dst) - - # Negative scenario - lstat.reset_mock() - lstat.side_effect = IOError - - # noinspection PyTypeChecker - result = ssh.exists(dst) - self.assertFalse(result) - lstat.assert_called_once_with(dst) - - def test_stat(self, client, *args): - ssh, _sftp = self.prepare_sftp_file_tests(client) - stat = mock.Mock() - _sftp.attach_mock(stat, "stat") - stat.return_value = paramiko.sftp_attr.SFTPAttributes() - stat.return_value.st_size = 0 - stat.return_value.st_uid = 0 - stat.return_value.st_gid = 0 - dst = "/etc/passwd" - - # noinspection PyTypeChecker - result = ssh.stat(dst) - self.assertEqual(result.st_size, 0) - self.assertEqual(result.st_uid, 0) - self.assertEqual(result.st_gid, 0) - - def test_isfile(self, client, *args): - class Attrs: - def __init__(self, mode): - self.st_mode = mode - - ssh, _sftp = self.prepare_sftp_file_tests(client) - lstat = mock.Mock() - _sftp.attach_mock(lstat, "lstat") - lstat.return_value = Attrs(stat.S_IFREG) - dst = "/etc/passwd" - - # noinspection PyTypeChecker - result = ssh.isfile(dst) - self.assertTrue(result) - lstat.assert_called_once_with(dst) - - # Negative scenario - lstat.reset_mock() - lstat.return_value = Attrs(stat.S_IFDIR) - - # noinspection PyTypeChecker - result = ssh.isfile(dst) - self.assertFalse(result) - lstat.assert_called_once_with(dst) - - lstat.reset_mock() - lstat.side_effect = IOError - - # noinspection PyTypeChecker - result = ssh.isfile(dst) - self.assertFalse(result) - lstat.assert_called_once_with(dst) - - def test_isdir(self, client, *args): - class Attrs: - def __init__(self, mode): - self.st_mode = mode - - ssh, _sftp = self.prepare_sftp_file_tests(client) - lstat = mock.Mock() - _sftp.attach_mock(lstat, "lstat") - lstat.return_value = Attrs(stat.S_IFDIR) - dst = "/etc/passwd" - - # noinspection PyTypeChecker - result = ssh.isdir(dst) - self.assertTrue(result) - lstat.assert_called_once_with(dst) - - # Negative scenario - lstat.reset_mock() - lstat.return_value = Attrs(stat.S_IFREG) - - # noinspection PyTypeChecker - result = ssh.isdir(dst) - self.assertFalse(result) - lstat.assert_called_once_with(dst) - - lstat.reset_mock() - lstat.side_effect = IOError - # noinspection PyTypeChecker - result = ssh.isdir(dst) - self.assertFalse(result) - lstat.assert_called_once_with(dst) - - @mock.patch("exec_helpers.ssh_client.SSHClient.exists") - @mock.patch("exec_helpers.ssh_client.SSHClient.execute") - def test_mkdir(self, execute, exists, *args): - exists.side_effect = [False, True] - - dst = "~/tst dir" - escaped_dst = r"~/tst\ dir" - - # noinspection PyTypeChecker - ssh = exec_helpers.SSHClient( - host=host, port=port, auth=exec_helpers.SSHAuth(username=username, password=password) - ) - - # Path not exists - # noinspection PyTypeChecker - ssh.mkdir(dst) - exists.assert_called_once_with(dst) - execute.assert_called_once_with("mkdir -p {}\n".format(escaped_dst)) - - # Path exists - exists.reset_mock() - execute.reset_mock() - - # noinspection PyTypeChecker - ssh.mkdir(dst) - exists.assert_called_once_with(dst) - execute.assert_not_called() - - @mock.patch("exec_helpers.ssh_client.SSHClient.execute") - def test_rm_rf(self, execute, *args): - dst = "~/tst" - - # noinspection PyTypeChecker - ssh = exec_helpers.SSHClient( - host=host, port=port, auth=exec_helpers.SSHAuth(username=username, password=password) - ) - - # Path not exists - # noinspection PyTypeChecker - ssh.rm_rf(dst) - execute.assert_called_once_with("rm -rf {}".format(dst)) - - def test_open(self, client, *args): - ssh, _sftp = self.prepare_sftp_file_tests(client) - fopen = mock.Mock(return_value=True) - _sftp.attach_mock(fopen, "open") - - dst = "/etc/passwd" - mode = "r" - # noinspection PyTypeChecker - result = ssh.open(dst) - fopen.assert_called_once_with(dst, mode) - self.assertTrue(result) - - @mock.patch("exec_helpers.ssh_client.logger", autospec=True) - @mock.patch("exec_helpers.ssh_client.SSHClient.exists") - @mock.patch("os.path.exists", autospec=True) - @mock.patch("exec_helpers.ssh_client.SSHClient.isdir") - @mock.patch("os.path.isdir", autospec=True) - def test_download(self, isdir, remote_isdir, exists, remote_exists, logger, client, policy, _logger): - ssh, _sftp = self.prepare_sftp_file_tests(client) - isdir.return_value = True - exists.side_effect = [True, False, False] - remote_isdir.side_effect = [False, False, True] - remote_exists.side_effect = [True, False, False] - - dst = "/etc/environment" - target = "/tmp/environment" - # noinspection PyTypeChecker - result = ssh.download(destination=dst, target=target) - self.assertTrue(result) - isdir.assert_called_once_with(target) - exists.assert_called_once_with(posixpath.join(target, os.path.basename(dst))) - remote_isdir.assert_called_once_with(dst) - remote_exists.assert_called_once_with(dst) - _sftp.assert_has_calls((mock.call.get(dst, posixpath.join(target, os.path.basename(dst))),)) - - # Negative scenarios - # noinspection PyTypeChecker - result = ssh.download(destination=dst, target=target) - self.assertFalse(result) - - # noinspection PyTypeChecker - ssh.download(destination=dst, target=target) - - @mock.patch("exec_helpers.ssh_client.SSHClient.isdir") - @mock.patch("os.path.isdir", autospec=True) - def test_upload_file(self, isdir, remote_isdir, client, *args): - ssh, _sftp = self.prepare_sftp_file_tests(client) - isdir.return_value = False - remote_isdir.return_value = False - target = "/etc/environment" - source = "/tmp/environment" - - # noinspection PyTypeChecker - ssh.upload(source=source, target=target) - isdir.assert_called_once_with(source) - remote_isdir.assert_called_once_with(target) - _sftp.assert_has_calls((mock.call.put(source, target),)) - - @mock.patch("exec_helpers.ssh_client.SSHClient.exists") - @mock.patch("exec_helpers.ssh_client.SSHClient.mkdir") - @mock.patch("os.walk") - @mock.patch("exec_helpers.ssh_client.SSHClient.isdir") - @mock.patch("os.path.isdir", autospec=True) - def test_upload_dir(self, isdir, remote_isdir, walk, mkdir, exists, client, *args): - ssh, _sftp = self.prepare_sftp_file_tests(client) - isdir.return_value = True - remote_isdir.return_value = True - exists.return_value = True - target = "/etc" - source = "/tmp/bash" - filename = "bashrc" - walk.return_value = ((source, "", [filename]),) - expected_path = posixpath.join(target, os.path.basename(source)) - expected_file = posixpath.join(expected_path, filename) - - # noinspection PyTypeChecker - ssh.upload(source=source, target=target) - isdir.assert_called_once_with(source) - remote_isdir.assert_called_once_with(target) - mkdir.assert_called_once_with(expected_path) - exists.assert_called_once_with(expected_file) - _sftp.assert_has_calls( - ( - mock.call.unlink(expected_file), - mock.call.put(os.path.normpath(os.path.join(source, filename)), expected_file), - ) - ) diff --git a/test/test_ssh_client_execute.py b/test/test_ssh_client_execute.py new file mode 100644 index 0000000..03248c7 --- /dev/null +++ b/test/test_ssh_client_execute.py @@ -0,0 +1,508 @@ +# Copyright 2018 Alexey Stepanov aka penguinolog. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import logging +import typing + +import mock +import pytest + +import exec_helpers + + +class FakeFileStream: + """Mock-like object for stream emulation.""" + + def __init__(self, *args): + self.__src = list(args) + self.closed = False + + def __iter__(self): + """Normally we iter over source.""" + for _ in range(len(self.__src)): + yield self.__src.pop(0) + + def fileno(self): + return hash(tuple(self.__src)) + + def close(self): + """We enforce close.""" + self.closed = True + + +def read_stream(stream: FakeFileStream) -> typing.Tuple[bytes, ...]: + return tuple([line for line in stream]) + + +host = "127.0.0.1" +host2 = "127.0.0.2" +port = 22 +username = "user" +password = "pass" + +command = "ls ~\nline 2\nline 3\nline с кирилицей" +command_log = "Executing command:\n{!r}\n".format(command.rstrip()) + +print_stdin = 'read line; echo "$line"' +default_timeout = 60 * 60 # 1 hour + + +configs = { + "positive_simple": dict( + ec=0, stdout=(b" \n", b"2\n", b"3\n", b" \n"), stderr=(), stdin=None, open_stdout=True, open_stderr=True + ), + "with_pty": dict( + ec=0, + stdout=(b" \n", b"2\n", b"3\n", b" \n"), + stderr=(), + stdin=None, + open_stdout=True, + open_stderr=True, + get_pty=True, + ), + "with_pty_nonstandard": dict( + ec=0, + stdout=(b" \n", b"2\n", b"3\n", b" \n"), + stderr=(), + stdin=None, + open_stdout=True, + open_stderr=True, + get_pty=True, + width=120, + height=100, + ), + "with_stderr": dict( + ec=0, + stdout=(b" \n", b"2\n", b"3\n", b" \n"), + stderr=(b" \n", b"0\n", b"1\n", b" \n"), + stdin=None, + open_stdout=True, + open_stderr=True, + ), + "negative": dict( + ec=1, + stdout=(b" \n", b"2\n", b"3\n", b" \n"), + stderr=(b" \n", b"0\n", b"1\n", b" \n"), + stdin=None, + open_stdout=True, + open_stderr=True, + ), + "with_stdin_str": dict( + ec=0, stdout=(b" \n", b"2\n", b"3\n", b" \n"), stderr=(), stdin="stdin", open_stdout=True, open_stderr=True + ), + "with_stdin_bytes": dict( + ec=0, stdout=(b" \n", b"2\n", b"3\n", b" \n"), stderr=(), stdin=b"stdin", open_stdout=True, open_stderr=True + ), + "with_stdin_bytearray": dict( + ec=0, + stdout=(b" \n", b"2\n", b"3\n", b" \n"), + stderr=(), + stdin=bytearray(b"stdin"), + open_stdout=True, + open_stderr=True, + ), + "no_stderr": dict( + ec=0, stdout=(b" \n", b"2\n", b"3\n", b" \n"), stderr=(), stdin=None, open_stdout=True, open_stderr=False + ), + "no_stdout": dict(ec=0, stdout=(), stderr=(), stdin=None, open_stdout=False, open_stderr=False), +} + + +def pytest_generate_tests(metafunc): + if "run_parameters" in metafunc.fixturenames: + metafunc.parametrize( + "run_parameters", + [ + "positive_simple", + "with_pty", + "with_pty_nonstandard", + "with_stderr", + "negative", + "with_stdin_str", + "with_stdin_bytes", + "with_stdin_bytearray", + "no_stderr", + "no_stdout", + ], + indirect=True, + ) + + +@pytest.fixture +def run_parameters(request): + return configs[request.param] + + +@pytest.fixture +def auto_add_policy(mocker): + return mocker.patch("paramiko.AutoAddPolicy", return_value="AutoAddPolicy") + + +@pytest.fixture +def paramiko_ssh_client(mocker): + mocker.patch("time.sleep") + return mocker.patch("paramiko.SSHClient") + + +@pytest.fixture +def chan_makefile(run_parameters): + class MkFile: + def __init__(self): + self.stdin = None + self.stdout = None + self.channel = None + + def __call__(self, flags: str): + if "wb" == flags: + self.stdin = mock.Mock() + self.stdin.channel = self.channel + return self.stdin + elif "rb" == flags: + self.stdout = FakeFileStream(*run_parameters["stdout"]) + return self.stdout + raise ValueError("Unexpected flags: {!r}".format(flags)) + + return MkFile() + + +@pytest.fixture +def ssh_transport_channel(paramiko_ssh_client, chan_makefile, run_parameters): + chan = mock.Mock(makefile=chan_makefile, closed=False) + chan_makefile.channel = chan + if run_parameters["open_stderr"]: + chan.attach_mock(mock.Mock(return_value=FakeFileStream(*run_parameters["stderr"])), "makefile_stderr") + chan.configure_mock(exit_status=run_parameters["ec"]) + chan.attach_mock(mock.Mock(return_value=run_parameters["ec"]), "recv_exit_status") + chan.status_event.attach_mock(mock.Mock(return_value=True), "is_set") + open_session = mock.Mock(return_value=chan) + transport = mock.Mock() + transport.attach_mock(open_session, "open_session") + get_transport = mock.Mock(return_value=transport) + _ssh = mock.Mock() + _ssh.attach_mock(get_transport, "get_transport") + paramiko_ssh_client.return_value = _ssh + return chan + + +@pytest.fixture +def ssh_auth_logger(mocker): + return mocker.patch("exec_helpers.ssh_auth.logger") + + +@pytest.fixture +def get_logger(mocker): + return mocker.patch("logging.getLogger") + + +@pytest.fixture +def ssh(paramiko_ssh_client, ssh_transport_channel, auto_add_policy, ssh_auth_logger, get_logger): + return exec_helpers.SSHClient(host=host, port=port, auth=exec_helpers.SSHAuth(username=username, password=password)) + + +@pytest.fixture +def ssh2(paramiko_ssh_client, ssh_transport_channel, auto_add_policy, ssh_auth_logger, get_logger): + return exec_helpers.SSHClient( + host=host2, port=port, auth=exec_helpers.SSHAuth(username=username, password=password) + ) + + +@pytest.fixture +def exec_result(run_parameters): + return exec_helpers.ExecResult( + cmd=command, + stdin=run_parameters["stdin"], + stdout=tuple([line for line in run_parameters["stdout"]]) if run_parameters["stdout"] else None, + stderr=tuple([line for line in run_parameters["stderr"]]) if run_parameters["stderr"] else None, + exit_code=run_parameters["ec"], + ) + + +@pytest.fixture +def execute_async(mocker, run_parameters): + def get_patched_execute_async_retval( + ec=0, stdout=(), stderr=(), open_stdout=True, open_stderr=True, **kwargs + ) -> exec_helpers.SshExecuteAsyncResult: + stdout_part = FakeFileStream(*stdout) if open_stdout else None + stderr_part = FakeFileStream(*stderr) if open_stderr else None + + exit_code = ec + chan = mock.Mock() + chan.attach_mock(mock.Mock(return_value=exit_code), "recv_exit_status") + + status_event = mock.Mock() + status_event.attach_mock(mock.Mock(), "wait") + chan.attach_mock(status_event, "status_event") + chan.configure_mock(exit_status=exit_code) + return exec_helpers.SshExecuteAsyncResult( + interface=chan, stdin=mock.Mock, stdout=stdout_part, stderr=stderr_part + ) + + return mocker.patch( + "exec_helpers.ssh_client.SSHClient.execute_async", + side_effect=[ + get_patched_execute_async_retval(**run_parameters), + get_patched_execute_async_retval(**run_parameters), + ], + ) + + +@pytest.fixture +def execute(mocker, exec_result): + return mocker.patch("exec_helpers.ssh_client.SSHClient.execute", name="execute", return_value=exec_result) + + +def teardown_function(function): + """Clean-up after tests.""" + with mock.patch("warnings.warn"): + exec_helpers.SSHClient._clear_cache() + + +def test_001_execute_async(ssh, paramiko_ssh_client, ssh_transport_channel, chan_makefile, run_parameters, get_logger): + open_stdout = run_parameters["open_stdout"] + open_stderr = run_parameters["open_stderr"] + get_pty = run_parameters.get("get_pty", False) + + kwargs = {} + if "get_pty" in run_parameters: + kwargs["get_pty"] = get_pty + if "width" in run_parameters: + kwargs["width"] = run_parameters["width"] + if "height" in run_parameters: + kwargs["height"] = run_parameters["height"] + + res = ssh.execute_async( + command, stdin=run_parameters["stdin"], open_stdout=open_stdout, open_stderr=open_stderr, **kwargs + ) + assert isinstance(res, exec_helpers.SshExecuteAsyncResult) + assert res.interface is ssh_transport_channel + assert res.stdin is chan_makefile.stdin + assert res.stdout is chan_makefile.stdout + + paramiko_ssh_client.assert_has_calls( + ( + mock.call(), + mock.call().set_missing_host_key_policy("AutoAddPolicy"), + mock.call().connect(hostname="127.0.0.1", password="pass", pkey=None, port=22, username="user"), + mock.call().get_transport(), + ) + ) + + transport_calls = [] + if get_pty: + transport_calls.append( + mock.call.get_pty( + term="vt100", + width=run_parameters.get("width", 80), + height=run_parameters.get("height", 24), + width_pixels=0, + height_pixels=0, + ) + ) + if open_stderr: + transport_calls.append(mock.call.makefile_stderr("rb")) + transport_calls.append(mock.call.exec_command("{}\n".format(command))) + + ssh_transport_channel.assert_has_calls(transport_calls) + + stdout = run_parameters["stdout"] + stderr = run_parameters["stderr"] + + if open_stdout: + assert read_stream(res.stdout) == stdout + else: + assert res.stdout is None + if open_stderr: + assert read_stream(res.stderr) == stderr + else: + assert res.stderr is None + + if run_parameters["stdin"] is None: + stdin = None + elif isinstance(run_parameters["stdin"], bytes): + stdin = run_parameters["stdin"].decode("utf-8") + elif isinstance(run_parameters["stdin"], str): + stdin = run_parameters["stdin"] + else: + stdin = bytes(run_parameters["stdin"]).decode("utf-8") + + assert res.stdin.channel == res.interface + + if stdin: + res.stdin.write.assert_called_with("{stdin}\n".format(stdin=stdin).encode("utf-8")) + res.stdin.flush.assert_called_once() + log = get_logger(ssh.__class__.__name__).getChild("{host}:{port}".format(host=host, port=port)) + log.log.assert_called_once_with(level=logging.DEBUG, msg=command_log) + + +def test_002_execute(ssh, ssh_transport_channel, exec_result, run_parameters) -> None: + kwargs = {} + if "get_pty" in run_parameters: + kwargs["get_pty"] = run_parameters["get_pty"] + if "width" in run_parameters: + kwargs["width"] = run_parameters["width"] + if "height" in run_parameters: + kwargs["height"] = run_parameters["height"] + + res = ssh.execute( + command, + stdin=run_parameters["stdin"], + open_stdout=run_parameters["open_stdout"], + open_stderr=run_parameters["open_stderr"], + **kwargs + ) + assert isinstance(res, exec_helpers.ExecResult) + assert res == exec_result + ssh_transport_channel.assert_has_calls((mock.call.status_event.is_set(),)) + + +def test_003_context_manager(ssh, exec_result, run_parameters, mocker) -> None: + kwargs = {} + if "get_pty" in run_parameters: + kwargs["get_pty"] = run_parameters["get_pty"] + if "width" in run_parameters: + kwargs["width"] = run_parameters["width"] + if "height" in run_parameters: + kwargs["height"] = run_parameters["height"] + + with mocker.patch("threading.RLock") as lock: + with ssh: + res = ssh.execute( + command, + stdin=run_parameters["stdin"], + open_stdout=run_parameters["open_stdout"], + open_stderr=run_parameters["open_stderr"], + **kwargs + ) + lock.acquire_assert_called_once() + lock.release_assert_called_once() + assert isinstance(res, exec_helpers.ExecResult) + assert res == exec_result + + +def test_004_check_call(ssh, exec_result, get_logger, mocker) -> None: + mocker.patch("exec_helpers.ssh_client.SSHClient.execute", return_value=exec_result) + ssh_logger = get_logger(exec_helpers.SSHClient.__name__) + log = ssh_logger.getChild("{host}:{port}".format(host=host, port=port)) + + if exec_result.exit_code == exec_helpers.ExitCodes.EX_OK: + assert ssh.check_call(command, stdin=exec_result.stdin) == exec_result + else: + with pytest.raises(exec_helpers.CalledProcessError) as e: + ssh.check_call(command, stdin=exec_result.stdin) + + exc = e.value # type: exec_helpers.CalledProcessError + assert exc.cmd == exec_result.cmd + assert exc.returncode == exec_result.exit_code + assert exc.stdout == exec_result.stdout_str + assert exc.stderr == exec_result.stderr_str + assert exc.result == exec_result + assert exc.expected == [exec_helpers.ExitCodes.EX_OK] + + assert log.mock_calls[-1] == mock.call.error( + msg="Command {result.cmd!r} returned exit code {result.exit_code!s} while expected {expected!r}".format( + result=exc.result, expected=exc.expected + ) + ) + + +def test_005_check_call_no_raise(ssh, exec_result, get_logger, mocker) -> None: + mocker.patch("exec_helpers.ssh_client.SSHClient.execute", return_value=exec_result) + ssh_logger = get_logger(exec_helpers.SSHClient.__name__) + log = ssh_logger.getChild("{host}:{port}".format(host=host, port=port)) + + res = ssh.check_call(command, stdin=exec_result.stdin, raise_on_err=False) + assert res == exec_result + + if exec_result.exit_code != exec_helpers.ExitCodes.EX_OK: + assert log.mock_calls[-1] == mock.call.error( + msg="Command {result.cmd!r} returned exit code {result.exit_code!s} while expected {expected!r}".format( + result=res, expected=[exec_helpers.ExitCodes.EX_OK] + ) + ) + + +def test_006_check_call_expect(ssh, exec_result, mocker) -> None: + mocker.patch("exec_helpers.ssh_client.SSHClient.execute", return_value=exec_result) + assert ssh.check_call(command, stdin=exec_result.stdin, expected=[exec_result.exit_code]) == exec_result + + +def test_007_check_stderr(ssh, exec_result, get_logger, mocker) -> None: + mocker.patch("exec_helpers.ssh_client.SSHClient.check_call", return_value=exec_result) + ssh_logger = get_logger(exec_helpers.SSHClient.__name__) + log = ssh_logger.getChild("{host}:{port}".format(host=host, port=port)) + + if not exec_result.stderr: + assert ssh.check_stderr(command, stdin=exec_result.stdin, expected=[exec_result.exit_code]) == exec_result + else: + with pytest.raises(exec_helpers.CalledProcessError) as e: + ssh.check_stderr(command, stdin=exec_result.stdin, expected=[exec_result.exit_code]) + exc = e.value # type: exec_helpers.CalledProcessError + assert exc.result == exec_result + assert exc.cmd == exec_result.cmd + assert exc.returncode == exec_result.exit_code + assert exc.stdout == exec_result.stdout_str + assert exc.stderr == exec_result.stderr_str + assert exc.result == exec_result + + assert log.mock_calls[-1] == mock.call.error( + msg="Command {result.cmd!r} output contains STDERR while not expected\n" + "\texit code: {result.exit_code!s}".format(result=exc.result) + ) + + +def test_008_check_stderr_no_raise(ssh, exec_result, mocker) -> None: + mocker.patch("exec_helpers.ssh_client.SSHClient.check_call", return_value=exec_result) + assert ( + ssh.check_stderr(command, stdin=exec_result.stdin, expected=[exec_result.exit_code], raise_on_err=False) + == exec_result + ) + + +def test_009_execute_together(ssh, ssh2, execute_async, exec_result, run_parameters): + + remotes = [ssh, ssh2] + + if 0 == run_parameters["ec"]: + results = exec_helpers.SSHClient.execute_together( + remotes=remotes, command=command, stdin=run_parameters.get("stdin", None) + ) + execute_async.assert_has_calls( + ( + mock.call(command, stdin=run_parameters.get("stdin", None)), + mock.call(command, stdin=run_parameters.get("stdin", None)), + ) + ) + assert results == {(host, port): exec_result, (host2, port): exec_result} + else: + with pytest.raises(exec_helpers.ParallelCallProcessError) as e: + exec_helpers.SSHClient.execute_together(remotes=remotes, command=command) + exc = e.value # type: exec_helpers.ParallelCallProcessError + assert exc.cmd == command + assert exc.expected == [exec_helpers.ExitCodes.EX_OK] + assert exc.results == {(host, port): exec_result, (host2, port): exec_result} + + +def test_010_execute_together_expected(ssh, ssh2, execute_async, exec_result, run_parameters): + remotes = [ssh, ssh2] + + results = exec_helpers.SSHClient.execute_together( + remotes=remotes, command=command, stdin=run_parameters.get("stdin", None), expected=[run_parameters["ec"]] + ) + execute_async.assert_has_calls( + ( + mock.call(command, stdin=run_parameters.get("stdin", None)), + mock.call(command, stdin=run_parameters.get("stdin", None)), + ) + ) + assert results == {(host, port): exec_result, (host2, port): exec_result} diff --git a/test/test_ssh_client_execute_async_special.py b/test/test_ssh_client_execute_async_special.py new file mode 100644 index 0000000..e1b041a --- /dev/null +++ b/test/test_ssh_client_execute_async_special.py @@ -0,0 +1,255 @@ +# Copyright 2018 Alexey Stepanov aka penguinolog. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import base64 +import typing + +import mock +import pytest + +import exec_helpers + + +class FakeFileStream: + """Mock-like object for stream emulation.""" + + def __init__(self, *args): + self.__src = list(args) + self.closed = False + self.channel = None + + def __iter__(self): + """Normally we iter over source.""" + for _ in range(len(self.__src)): + yield self.__src.pop(0) + + def fileno(self): + return hash(tuple(self.__src)) + + def close(self): + """We enforce close.""" + self.closed = True + + +def read_stream(stream: FakeFileStream) -> typing.Tuple[bytes, ...]: + return tuple([line for line in stream]) + + +host = "127.0.0.1" +port = 22 +username = "user" +password = "pass" + +command = "ls ~\nline 2\nline 3\nline с кирилицей" +command_log = "Executing command:\n{!r}\n".format(command.rstrip()) +stdout_src = (b" \n", b"2\n", b"3\n", b" \n") +stderr_src = (b" \n", b"0\n", b"1\n", b" \n") +encoded_cmd = base64.b64encode("{}\n".format(command).encode("utf-8")).decode("utf-8") + +print_stdin = 'read line; echo "$line"' +default_timeout = 60 * 60 # 1 hour + + +@pytest.fixture +def auto_add_policy(mocker): + return mocker.patch("paramiko.AutoAddPolicy", return_value="AutoAddPolicy") + + +@pytest.fixture +def paramiko_ssh_client(mocker): + mocker.patch("time.sleep") + return mocker.patch("paramiko.SSHClient") + + +@pytest.fixture +def chan_makefile(): + class MkFile: + def __init__(self): + self.stdin = None + self.stdout = None + self.channel = None + + def __call__(self, flags: str): + if "wb" == flags: + self.stdin = mock.Mock() + self.stdin.channel = self.channel + return self.stdin + elif "rb" == flags: + self.stdout = FakeFileStream(*stdout_src) + self.stdout.channel = self.channel + return self.stdout + raise ValueError("Unexpected flags: {!r}".format(flags)) + + return MkFile() + + +@pytest.fixture +def ssh_transport_channel(paramiko_ssh_client, chan_makefile): + chan = mock.Mock(makefile=chan_makefile, closed=False) + chan_makefile.channel = chan + chan.attach_mock(mock.Mock(return_value=FakeFileStream(*stderr_src)), "makefile_stderr") + chan.configure_mock(exit_status=0) + chan.status_event.attach_mock(mock.Mock(return_value=True), "is_set") + open_session = mock.Mock(return_value=chan) + transport = mock.Mock() + transport.attach_mock(open_session, "open_session") + get_transport = mock.Mock(return_value=transport) + _ssh = mock.Mock() + _ssh.attach_mock(get_transport, "get_transport") + paramiko_ssh_client.return_value = _ssh + return chan + + +@pytest.fixture +def ssh_auth_logger(mocker): + return mocker.patch("exec_helpers.ssh_auth.logger") + + +@pytest.fixture +def get_logger(mocker): + return mocker.patch("logging.getLogger") + + +@pytest.fixture +def ssh(paramiko_ssh_client, ssh_transport_channel, auto_add_policy, ssh_auth_logger, get_logger): + return exec_helpers.SSHClient(host=host, port=port, auth=exec_helpers.SSHAuth(username=username, password=password)) + + +@pytest.fixture +def exec_result(): + return exec_helpers.ExecResult(cmd=command, stdin=None, stdout=stdout_src, stderr=stderr_src, exit_code=0) + + +def teardown_function(function): + """Clean-up after tests.""" + with mock.patch("warnings.warn"): + exec_helpers.SSHClient._clear_cache() + + +def test_001_execute_async_sudo(ssh, ssh_transport_channel): + ssh.sudo_mode = True + + ssh.execute_async(command) + ssh_transport_channel.assert_has_calls( + ( + mock.call.makefile_stderr("rb"), + mock.call.exec_command("sudo -S bash -c '" 'eval "$(base64 -d <(echo "{0}"))"\''.format(encoded_cmd)), + ) + ) + + +def test_002_execute_async_with_sudo_enforce(ssh, ssh_transport_channel): + assert ssh.sudo_mode is False + + with ssh.sudo(enforce=True): + ssh.execute_async(command) + ssh_transport_channel.assert_has_calls( + ( + mock.call.makefile_stderr("rb"), + mock.call.exec_command("sudo -S bash -c '" 'eval "$(base64 -d <(echo "{0}"))"\''.format(encoded_cmd)), + ) + ) + + +def test_003_execute_async_with_no_sudo_enforce(ssh, ssh_transport_channel): + ssh.sudo_mode = True + + with ssh.sudo(enforce=False): + ssh.execute_async(command) + ssh_transport_channel.assert_has_calls( + (mock.call.makefile_stderr("rb"), mock.call.exec_command("{}\n".format(command))) + ) + + +def test_004_execute_async_with_sudo_none_enforce(ssh, ssh_transport_channel): + ssh.sudo_mode = False + + with ssh.sudo(): + ssh.execute_async(command) + ssh_transport_channel.assert_has_calls( + (mock.call.makefile_stderr("rb"), mock.call.exec_command("{}\n".format(command))) + ) + + +def test_005_execute_async_sudo_password(ssh, ssh_transport_channel, mocker): + enter_password = mocker.patch("exec_helpers.ssh_auth.SSHAuth.enter_password") + + ssh.sudo_mode = True + + res = ssh.execute_async(command) + ssh_transport_channel.assert_has_calls( + ( + mock.call.makefile_stderr("rb"), + mock.call.exec_command("sudo -S bash -c '" 'eval "$(base64 -d <(echo "{0}"))"\''.format(encoded_cmd)), + ) + ) + + enter_password.assert_called_once_with(res.stdin) + + +def test_006_keepalive(ssh, paramiko_ssh_client): + with ssh: + pass + + paramiko_ssh_client().close.assert_not_called() + + +def test_007_no_keepalive(ssh, paramiko_ssh_client): + ssh.keepalive_mode = False + + with ssh: + pass + + paramiko_ssh_client().close.assert_called_once() + + +def test_008_keepalive_enforced(ssh, paramiko_ssh_client): + ssh.keepalive_mode = False + + with ssh.keepalive(): + pass + + paramiko_ssh_client().close.assert_not_called() + + +def test_009_no_keepalive_enforced(ssh, paramiko_ssh_client): + assert ssh.keepalive_mode is True + + with ssh.keepalive(enforce=False): + pass + + paramiko_ssh_client().close.assert_called_once() + + +def test_010_check_stdin_closed(paramiko_ssh_client, chan_makefile, auto_add_policy, get_logger): + chan = mock.Mock(makefile=chan_makefile, closed=True) + chan_makefile.channel = chan + chan.attach_mock(mock.Mock(return_value=FakeFileStream(*stderr_src)), "makefile_stderr") + chan.configure_mock(exit_status=0) + chan.status_event.attach_mock(mock.Mock(return_value=True), "is_set") + open_session = mock.Mock(return_value=chan) + transport = mock.Mock() + transport.attach_mock(open_session, "open_session") + get_transport = mock.Mock(return_value=transport) + _ssh = mock.Mock() + _ssh.attach_mock(get_transport, "get_transport") + paramiko_ssh_client.return_value = _ssh + + stdin_val = "this is a line" + + ssh = exec_helpers.SSHClient(host=host, port=port, auth=exec_helpers.SSHAuth(username=username, password=password)) + ssh.execute_async(command=print_stdin, stdin=stdin_val) + + log = get_logger(ssh.__class__.__name__).getChild("{host}:{port}".format(host=host, port=port)) + log.warning.assert_called_once_with("STDIN Send failed: closed channel") diff --git a/test/test_ssh_client_execute_special.py b/test/test_ssh_client_execute_special.py new file mode 100644 index 0000000..5fa80e1 --- /dev/null +++ b/test/test_ssh_client_execute_special.py @@ -0,0 +1,217 @@ +# Copyright 2018 Alexey Stepanov aka penguinolog. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import base64 +import logging +import threading +import typing + +import mock +import pytest + +import exec_helpers + + +class FakeFileStream: + """Mock-like object for stream emulation.""" + + def __init__(self, *args): + self.__src = list(args) + self.closed = False + self.channel = None + + def __iter__(self): + """Normally we iter over source.""" + for _ in range(len(self.__src)): + yield self.__src.pop(0) + + def fileno(self): + return hash(tuple(self.__src)) + + def close(self): + """We enforce close.""" + self.closed = True + + +def read_stream(stream: FakeFileStream) -> typing.Tuple[bytes, ...]: + return tuple([line for line in stream]) + + +host = "127.0.0.1" +host2 = "127.0.0.2" +port = 22 +username = "user" +password = "pass" + +command = "ls ~\nline 2\nline 3\nline с кирилицей" +command_log = "Executing command:\n{!r}\n".format(command.rstrip()) +stdout_src = (b" \n", b"2\n", b"3\n", b" \n") +stderr_src = (b" \n", b"0\n", b"1\n", b" \n") +encoded_cmd = base64.b64encode("{}\n".format(command).encode("utf-8")).decode("utf-8") + +print_stdin = 'read line; echo "$line"' +default_timeout = 60 * 60 # 1 hour + + +@pytest.fixture +def auto_add_policy(mocker): + return mocker.patch("paramiko.AutoAddPolicy", return_value="AutoAddPolicy") + + +@pytest.fixture +def paramiko_ssh_client(mocker): + mocker.patch("time.sleep") + return mocker.patch("paramiko.SSHClient") + + +@pytest.fixture +def chan_makefile(): + class MkFile: + def __init__(self): + self.stdin = None + self.stdout = None + self.channel = None + + def __call__(self, flags: str): + if "wb" == flags: + self.stdin = mock.Mock() + self.stdin.channel = self.channel + return self.stdin + elif "rb" == flags: + self.stdout = FakeFileStream(*stdout_src) + self.stdout.channel = self.channel + return self.stdout + raise ValueError("Unexpected flags: {!r}".format(flags)) + + return MkFile() + + +@pytest.fixture +def ssh_transport_channel(paramiko_ssh_client, chan_makefile): + chan = mock.Mock(makefile=chan_makefile, closed=False) + chan_makefile.channel = chan + chan.attach_mock(mock.Mock(return_value=FakeFileStream(*stderr_src)), "makefile_stderr") + chan.configure_mock(exit_status=0) + chan.status_event.attach_mock(mock.Mock(return_value=True), "is_set") + open_session = mock.Mock(return_value=chan) + transport = mock.Mock() + transport.attach_mock(open_session, "open_session") + get_transport = mock.Mock(return_value=transport) + _ssh = mock.Mock() + _ssh.attach_mock(get_transport, "get_transport") + paramiko_ssh_client.return_value = _ssh + return chan + + +@pytest.fixture +def ssh_auth_logger(mocker): + return mocker.patch("exec_helpers.ssh_auth.logger") + + +@pytest.fixture +def get_logger(mocker): + return mocker.patch("logging.getLogger") + + +@pytest.fixture +def ssh(paramiko_ssh_client, ssh_transport_channel, auto_add_policy, ssh_auth_logger, get_logger): + return exec_helpers.SSHClient(host=host, port=port, auth=exec_helpers.SSHAuth(username=username, password=password)) + + +@pytest.fixture +def ssh2(paramiko_ssh_client, ssh_transport_channel, auto_add_policy, ssh_auth_logger, get_logger): + return exec_helpers.SSHClient( + host=host2, port=port, auth=exec_helpers.SSHAuth(username=username, password=password) + ) + + +@pytest.fixture +def exec_result(): + return exec_helpers.ExecResult(cmd=command, stdin=None, stdout=stdout_src, stderr=stderr_src, exit_code=0) + + +def teardown_function(function): + """Clean-up after tests.""" + with mock.patch("warnings.warn"): + exec_helpers.SSHClient._clear_cache() + + +def test_001_mask_command(ssh, get_logger) -> None: + cmd = "USE='secret=secret_pass' do task" + log_mask_re = r"secret\s*=\s*([A-Z-a-z0-9_\-]+)" + masked_cmd = "USE='secret=<*masked*>' do task" + cmd_log = "Executing command:\n{!r}\n".format(masked_cmd) + done_log = "Command {!r} exit code: {!s}".format(masked_cmd, exec_helpers.ExitCodes.EX_OK) + + log = get_logger(ssh.__class__.__name__).getChild("{host}:{port}".format(host=host, port=port)) + res = ssh.execute(cmd, log_mask_re=log_mask_re) + assert res.cmd == masked_cmd + assert log.mock_calls[0] == mock.call.log(level=logging.DEBUG, msg=cmd_log) + assert log.mock_calls[-1] == mock.call.log(level=logging.DEBUG, msg=done_log) + + +def test_002_mask_command_global(ssh, get_logger) -> None: + cmd = "USE='secret=secret_pass' do task" + log_mask_re = r"secret\s*=\s*([A-Z-a-z0-9_\-]+)" + masked_cmd = "USE='secret=<*masked*>' do task" + cmd_log = "Executing command:\n{!r}\n".format(masked_cmd) + done_log = "Command {!r} exit code: {!s}".format(masked_cmd, exec_helpers.ExitCodes.EX_OK) + + log = get_logger(ssh.__class__.__name__).getChild("{host}:{port}".format(host=host, port=port)) + + ssh.log_mask_re = log_mask_re + res = ssh.execute(cmd) + assert res.cmd == masked_cmd + assert log.mock_calls[0] == mock.call.log(level=logging.DEBUG, msg=cmd_log) + assert log.mock_calls[-1] == mock.call.log(level=logging.DEBUG, msg=done_log) + + +def test_003_execute_verbose(ssh, get_logger) -> None: + cmd_log = "Executing command:\n{!r}\n".format(command) + done_log = "Command {!r} exit code: {!s}".format(command, exec_helpers.ExitCodes.EX_OK) + + log = get_logger(ssh.__class__.__name__).getChild("{host}:{port}".format(host=host, port=port)) + ssh.execute(command, verbose=True) + + assert log.mock_calls[0] == mock.call.log(level=logging.INFO, msg=cmd_log) + assert log.mock_calls[-1] == mock.call.log(level=logging.INFO, msg=done_log) + + +def test_004_execute_timeout(ssh) -> None: + """We allow timeout and not crush on it if fit.""" + ssh.execute(command, timeout=0.01) + + +def test_005_execute_timeout_fail(ssh, ssh_transport_channel, exec_result) -> None: + """We allow timeout and not crush on it if fit.""" + ssh_transport_channel.status_event = threading.Event() + with pytest.raises(exec_helpers.ExecHelperTimeoutError) as e: + ssh.execute(command, timeout=0.01) + exc = e.value # type: exec_helpers.ExecHelperTimeoutError + assert exc.timeout == 0.01 + assert exc.cmd == command + assert exc.stdout == exec_result.stdout_str + assert exc.stderr == exec_result.stderr_str + + +def test_006_execute_together_exceptions(ssh, ssh2, mocker) -> None: + mocker.patch("exec_helpers.ssh_client.SSHClient.execute_async", side_effect=RuntimeError) + remotes = [ssh, ssh2] + + with pytest.raises(exec_helpers.ParallelCallExceptions) as e: + ssh.execute_together(remotes=remotes, command=command) + exc = e.value # type: exec_helpers.ParallelCallExceptions + assert list(sorted(exc.exceptions)) == [(host, port), (host2, port)] + for exception in exc.exceptions.values(): + assert isinstance(exception, RuntimeError) diff --git a/test/test_ssh_client_execute_throw_host.py b/test/test_ssh_client_execute_throw_host.py new file mode 100644 index 0000000..62b59cb --- /dev/null +++ b/test/test_ssh_client_execute_throw_host.py @@ -0,0 +1,183 @@ +# Copyright 2018 Alexey Stepanov aka penguinolog. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import base64 +import typing + +import mock +import pytest + +import exec_helpers + + +class FakeFileStream: + """Mock-like object for stream emulation.""" + + def __init__(self, *args): + self.__src = list(args) + self.closed = False + self.channel = None + + def __iter__(self): + """Normally we iter over source.""" + for _ in range(len(self.__src)): + yield self.__src.pop(0) + + def fileno(self): + return hash(tuple(self.__src)) + + def close(self): + """We enforce close.""" + self.closed = True + + +def read_stream(stream: FakeFileStream) -> typing.Tuple[bytes, ...]: + return tuple([line for line in stream]) + + +host = "127.0.0.1" +port = 22 +username = "user" +password = "pass" + +command = "ls ~\nline 2\nline 3\nline с кирилицей" +command_log = "Executing command:\n{!r}\n".format(command.rstrip()) +stdout_src = (b" \n", b"2\n", b"3\n", b" \n") +stderr_src = (b" \n", b"0\n", b"1\n", b" \n") +encoded_cmd = base64.b64encode("{}\n".format(command).encode("utf-8")).decode("utf-8") + +print_stdin = 'read line; echo "$line"' +default_timeout = 60 * 60 # 1 hour + + +@pytest.fixture +def auto_add_policy(mocker): + return mocker.patch("paramiko.AutoAddPolicy", return_value="AutoAddPolicy") + + +@pytest.fixture +def paramiko_ssh_client(mocker): + mocker.patch("time.sleep") + return mocker.patch("paramiko.SSHClient") + + +@pytest.fixture +def chan_makefile(): + class MkFile: + def __init__(self): + self.stdin = None + self.stdout = None + self.channel = None + + def __call__(self, flags: str): + if "wb" == flags: + self.stdin = mock.Mock() + self.stdin.channel = self.channel + return self.stdin + elif "rb" == flags: + self.stdout = FakeFileStream(*stdout_src) + self.stdout.channel = self.channel + return self.stdout + raise ValueError("Unexpected flags: {!r}".format(flags)) + + return MkFile() + + +@pytest.fixture +def ssh_intermediate_channel(paramiko_ssh_client): + chan = mock.Mock(name="intermediate_channel") + transport = mock.Mock() + transport.attach_mock(chan, "open_channel") + get_transport = mock.Mock(return_value=transport) + _ssh = mock.Mock() + _ssh.attach_mock(get_transport, "get_transport") + paramiko_ssh_client.return_value = _ssh + return chan + + +@pytest.fixture +def ssh_transport(mocker): + transport = mock.Mock(name="transport") + mocker.patch("paramiko.Transport", return_value=transport) + return transport + + +@pytest.fixture +def ssh_transport_channel(chan_makefile, ssh_transport): + chan = mock.Mock(makefile=chan_makefile, closed=False) + chan_makefile.channel = chan + chan.attach_mock(mock.Mock(return_value=FakeFileStream(*stderr_src)), "makefile_stderr") + + chan.configure_mock(exit_status=0) + + chan.status_event.attach_mock(mock.Mock(return_value=True), "is_set") + open_session = mock.Mock(return_value=chan) + ssh_transport.attach_mock(open_session, "open_session") + return chan + + +@pytest.fixture +def ssh_auth_logger(mocker): + return mocker.patch("exec_helpers.ssh_auth.logger") + + +@pytest.fixture +def get_logger(mocker): + return mocker.patch("logging.getLogger") + + +@pytest.fixture +def ssh( + paramiko_ssh_client, ssh_intermediate_channel, ssh_transport_channel, auto_add_policy, ssh_auth_logger, get_logger +): + return exec_helpers.SSHClient(host=host, port=port, auth=exec_helpers.SSHAuth(username=username, password=password)) + + +@pytest.fixture +def exec_result(): + return exec_helpers.ExecResult(cmd=command, stdin=None, stdout=stdout_src, stderr=stderr_src, exit_code=0) + + +def teardown_function(function): + """Clean-up after tests.""" + with mock.patch("warnings.warn"): + exec_helpers.SSHClient._clear_cache() + + +def test_01_execute_through_host_no_creds(ssh, ssh_transport, exec_result) -> None: + target = "127.0.0.2" + result = ssh.execute_through_host(target, command) + ssh_transport.assert_has_calls( + (mock.call.connect(password=password, pkey=None, username=username), mock.call.open_session()) + ) + assert exec_result == result + + +def test_02_execute_through_host_with_creds(ssh, ssh_transport, exec_result) -> None: + target = "127.0.0.2" + username_2 = "user2" + password_2 = "pass2" + result = ssh.execute_through_host( + target, command, auth=exec_helpers.SSHAuth(username=username_2, password=password_2) + ) + ssh_transport.assert_has_calls( + (mock.call.connect(password=password_2, pkey=None, username=username_2), mock.call.open_session()) + ) + assert exec_result == result + + +def test_03_execute_get_pty(ssh, ssh_transport_channel) -> None: + target = "127.0.0.2" + ssh.execute_through_host(target, command, get_pty=True) + ssh_transport_channel.get_pty.assert_called_with(term="vt100", width=80, height=24, width_pixels=0, height_pixels=0) diff --git a/test/test_ssh_client_init_basic.py b/test/test_ssh_client_init_basic.py index 92ae190..0233cb8 100644 --- a/test/test_ssh_client_init_basic.py +++ b/test/test_ssh_client_init_basic.py @@ -112,6 +112,12 @@ def ssh_auth_logger(mocker): return mocker.patch("exec_helpers.ssh_auth.logger") +def teardown_function(function): + """Clean-up after tests.""" + with mock.patch("warnings.warn"): + exec_helpers.SSHClient._clear_cache() + + def test_init_base(paramiko_ssh_client, auto_add_policy, run_parameters, ssh_auth_logger): # Helper code _ssh = mock.call diff --git a/test/test_subprocess.py b/test/test_subprocess.py index c1f9b04..3fd79ea 100644 --- a/test/test_subprocess.py +++ b/test/test_subprocess.py @@ -228,7 +228,12 @@ def test_001_execute_async(popen, logger, run_parameters) -> None: def test_002_execute(popen, logger, exec_result, run_parameters) -> None: runner = exec_helpers.Subprocess() - res = runner.execute(command, stdin=run_parameters["stdin"]) + res = runner.execute( + command, + stdin=run_parameters["stdin"], + open_stdout=run_parameters["open_stdout"], + open_stderr=run_parameters["open_stderr"], + ) assert isinstance(res, exec_helpers.ExecResult) assert res == exec_result popen().wait.assert_called_once_with(timeout=default_timeout)