diff --git a/doc/source/SSHClient.rst b/doc/source/SSHClient.rst index 4875ea9..b976786 100644 --- a/doc/source/SSHClient.rst +++ b/doc/source/SSHClient.rst @@ -69,6 +69,11 @@ API: SSHClient and SSHAuth. ``bool`` Use sudo for all calls, except wrapped in connection.sudo context manager. + .. py:attribute:: keepalive_mode + + ``bool`` + Use keepalive mode for context manager. If `False` - close connection on exit from context manager. + .. py:method:: close() Close connection @@ -93,6 +98,7 @@ API: SSHClient and SSHAuth. .. versionchanged:: 1.0.0 disconnect enforced on close .. versionchanged:: 1.1.0 release lock on exit + .. versionchanged:: 1.2.1 disconnect enforced on close only not in keepalive mode .. py:method:: sudo(enforce=None) @@ -101,6 +107,16 @@ API: SSHClient and SSHAuth. :param enforce: Enforce sudo enabled or disabled. By default: None :type enforce: ``typing.Optional[bool]`` + .. py:method:: keepalive(enforce=None) + + Context manager getter for keepalive operation. + + :param enforce: Enforce keepalive enabled or disabled. By default: True + :type enforce: ``typing.bool`` + + .. Note:: Enter and exit ssh context manager is produced as well. + .. versionadded:: 1.2.1 + .. py:method:: execute_async(command, stdin=None, open_stdout=True, open_stderr=True, verbose=False, log_mask_re=None, **kwargs) Execute command in async mode and return channel with IO objects. diff --git a/exec_helpers/_ssh_client_base.py b/exec_helpers/_ssh_client_base.py index d9a24a7..6fb90e6 100644 --- a/exec_helpers/_ssh_client_base.py +++ b/exec_helpers/_ssh_client_base.py @@ -215,12 +215,19 @@ class SSHClientBase(six.with_metaclass(_MemorizedSSH, _api.ExecHelper)): """SSH Client helper.""" __slots__ = ( - '__hostname', '__port', '__auth', '__ssh', '__sftp', 'sudo_mode', + '__hostname', '__port', '__auth', '__ssh', '__sftp', + '__sudo_mode', '__keepalive_mode', ) class __get_sudo(object): """Context manager for call commands with sudo.""" + __slots__ = ( + '__ssh', + '__sudo_status', + '__enforce', + ) + def __init__( self, ssh, # type: SSHClientBase @@ -243,6 +250,40 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): self.__ssh.sudo_mode = self.__sudo_status + class __get_keepalive(object): + """Context manager for keepalive management.""" + + __slots__ = ( + '__ssh', + '__keepalive_status', + '__enforce', + ) + + def __init__( + self, + ssh, # type: SSHClientBase + enforce=True # type: bool + ): # type: (...) -> None + """Context manager for keepalive management. + + :type ssh: SSHClient + :type enforce: bool + :param enforce: Keep connection alive after context manager exit + """ + self.__ssh = ssh + self.__keepalive_status = ssh.keepalive_mode + self.__enforce = enforce + + def __enter__(self): + self.__keepalive_status = self.__ssh.keepalive_mode + if self.__enforce is not None: + self.__ssh.keepalive_mode = self.__enforce + self.__ssh.__enter__() + + def __exit__(self, exc_type, exc_val, exc_tb): + self.__ssh.__exit__(exc_type=exc_type, exc_val=exc_val, exc_tb=exc_tb) + self.__ssh.keepalive_mode = self.__keepalive_status + def __hash__(self): """Hash for usage as dict keys.""" return hash(( @@ -286,7 +327,9 @@ def __init__( self.__hostname = host self.__port = port - self.sudo_mode = False + self.__sudo_mode = False + self.__keepalive_mode = True + self.__ssh = paramiko.SSHClient() self.__ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) self.__sftp = None @@ -460,10 +503,44 @@ def __exit__(self, exc_type, exc_val, exc_tb): .. versionchanged:: 1.0.0 disconnect enforced on close .. versionchanged:: 1.1.0 release lock on exit + .. versionchanged:: 1.2.1 disconnect enforced on close only not in keepalive mode """ - self.close() + if not self.__keepalive_mode: + self.close() super(SSHClientBase, self).__exit__(exc_type, exc_val, exc_tb) + @property + def sudo_mode(self): # type: () -> bool + """Persistent sudo mode for connection object. + + :rtype: bool + """ + return self.__sudo_mode + + @sudo_mode.setter + def sudo_mode(self, mode): # type: (bool) -> None + """Persistent sudo mode change for connection object. + + :type mode: bool + """ + self.__sudo_mode = bool(mode) + + @property + def keepalive_mode(self): # type: () -> bool + """Persistent keepalive mode for connection object. + + :rtype: bool + """ + return self.__keepalive_mode + + @keepalive_mode.setter + def keepalive_mode(self, mode): # type: (bool) -> None + """Persistent keepalive mode change for connection object. + + :type mode: bool + """ + self.__keepalive_mode = bool(mode) + def reconnect(self): # type: () -> None """Reconnect SSH session.""" with self.lock: @@ -485,6 +562,20 @@ def sudo( """ return self.__get_sudo(ssh=self, enforce=enforce) + def keepalive( + self, + enforce=True # type: bool + ): + """Call contextmanager with keepalive mode change. + + :param enforce: Enforce keepalive enabled or disabled. + :type enforce: bool + + .. Note:: Enter and exit ssh context manager is produced as well. + .. versionadded:: 1.2.1 + """ + return self.__get_keepalive(ssh=self, enforce=enforce) + def execute_async( self, command, # type: str @@ -839,7 +930,8 @@ def get_result( list(futures.values()), timeout=timeout ) # type: typing.Set[concurrent.futures.Future], typing.Set[concurrent.futures.Future] - for future in not_done: + + for future in not_done: # pragma: no cover future.cancel() for ( diff --git a/exec_helpers/subprocess_runner.py b/exec_helpers/subprocess_runner.py index 1059707..24fa873 100644 --- a/exec_helpers/subprocess_runner.py +++ b/exec_helpers/subprocess_runner.py @@ -21,6 +21,8 @@ from __future__ import unicode_literals import collections +# noinspection PyCompatibility +import concurrent.futures import errno import logging import os @@ -203,7 +205,7 @@ def poll_streams( verbose=verbose ) - @threaded.threaded(started=True) + @threaded.threadpooled() def poll_pipes( result, # type: exec_result.ExecResult stop, # type: threading.Event @@ -245,17 +247,17 @@ def poll_pipes( stop_event = threading.Event() # pylint: disable=assignment-from-no-return - poll_thread = poll_pipes( + future = poll_pipes( result, stop_event - ) # type: threading.Thread + ) # type: concurrent.futures.Future # pylint: enable=assignment-from-no-return # wait for process close - stop_event.wait(timeout) + + concurrent.futures.wait([future], timeout) # Process closed? if stop_event.is_set(): - poll_thread.join(0.1) stop_event.clear() return result # Kill not ended process and wait for close @@ -264,7 +266,7 @@ def poll_pipes( stop_event.wait(5) # Force stop cycle if no exit code after kill stop_event.set() - poll_thread.join(5) + future.cancel() except OSError: # Nothing to kill logger.warning( diff --git a/test/test_ssh_client.py b/test/test_ssh_client.py index d466a2f..65064cb 100644 --- a/test/test_ssh_client.py +++ b/test/test_ssh_client.py @@ -92,7 +92,7 @@ def gen_cmd_result_log_message(result): return (u"Command exit code '{code!s}':\n{cmd!s}\n" .format(cmd=result.cmd.rstrip(), code=result.exit_code)) - def test_execute_async(self, client, policy, logger): + def test_001_execute_async(self, client, policy, logger): chan = mock.Mock() open_session = mock.Mock(return_value=chan) transport = mock.Mock() @@ -122,7 +122,7 @@ def test_execute_async(self, client, policy, logger): log.mock_calls ) - def test_execute_async_pty(self, client, policy, logger): + def test_002_execute_async_pty(self, client, policy, logger): chan = mock.Mock() open_session = mock.Mock(return_value=chan) transport = mock.Mock() @@ -157,7 +157,7 @@ def test_execute_async_pty(self, client, policy, logger): log.mock_calls ) - def test_execute_async_no_stdout_stderr(self, client, policy, logger): + def test_003_execute_async_no_stdout_stderr(self, client, policy, logger): chan = mock.Mock() open_session = mock.Mock(return_value=chan) transport = mock.Mock() @@ -208,7 +208,7 @@ def test_execute_async_no_stdout_stderr(self, client, policy, logger): mock.call.exec_command('{}\n'.format(command)) )) - def test_execute_async_sudo(self, client, policy, logger): + def test_004_execute_async_sudo(self, client, policy, logger): chan = mock.Mock() open_session = mock.Mock(return_value=chan) transport = mock.Mock() @@ -241,7 +241,7 @@ def test_execute_async_sudo(self, client, policy, logger): log.mock_calls ) - def test_execute_async_with_sudo_enforce(self, client, policy, logger): + def test_005_execute_async_with_sudo_enforce(self, client, policy, logger): chan = mock.Mock() open_session = mock.Mock(return_value=chan) transport = mock.Mock() @@ -277,7 +277,7 @@ def test_execute_async_with_sudo_enforce(self, client, policy, logger): log.mock_calls ) - def test_execute_async_with_no_sudo_enforce(self, client, policy, logger): + def test_006_execute_async_with_no_sudo_enforce(self, client, policy, logger): chan = mock.Mock() open_session = mock.Mock(return_value=chan) transport = mock.Mock() @@ -309,7 +309,7 @@ def test_execute_async_with_no_sudo_enforce(self, client, policy, logger): log.mock_calls ) - def test_execute_async_with_none_enforce(self, client, policy, logger): + def test_007_execute_async_with_sudo_none_enforce(self, client, policy, logger): chan = mock.Mock() open_session = mock.Mock(return_value=chan) transport = mock.Mock() @@ -342,7 +342,7 @@ def test_execute_async_with_none_enforce(self, client, policy, logger): ) @mock.patch('exec_helpers.ssh_auth.SSHAuth.enter_password') - def test_execute_async_sudo_password( + def test_008_execute_async_sudo_password( self, enter_password, client, policy, logger): stdin = mock.Mock(name='stdin') stdout = mock.Mock(name='stdout') @@ -386,7 +386,7 @@ def test_execute_async_sudo_password( log.mock_calls ) - def test_execute_async_verbose(self, client, policy, logger): + def test_009_execute_async_verbose(self, client, policy, logger): chan = mock.Mock() open_session = mock.Mock(return_value=chan) transport = mock.Mock() @@ -416,7 +416,7 @@ def test_execute_async_verbose(self, client, policy, logger): log.mock_calls ) - def test_execute_async_mask_command(self, client, policy, logger): + def test_010_execute_async_mask_command(self, client, policy, logger): cmd = "USE='secret=secret_pass' do task" log_mask_re = r"secret\s*=\s*([A-Z-a-z0-9_\-]+)" masked_cmd = "USE='secret=<*masked*>' do task" @@ -451,7 +451,7 @@ def test_execute_async_mask_command(self, client, policy, logger): log.mock_calls ) - def test_check_stdin_str(self, client, policy, logger): + def test_011_check_stdin_str(self, client, policy, logger): stdin_val = u'this is a line' stdin = mock.Mock(name='stdin') @@ -496,7 +496,7 @@ def test_check_stdin_str(self, client, policy, logger): mock.call.exec_command('{val}\n'.format(val=print_stdin)) )) - def test_check_stdin_bytes(self, client, policy, logger): + def test_012_check_stdin_bytes(self, client, policy, logger): stdin_val = b'this is a line' stdin = mock.Mock(name='stdin') @@ -541,7 +541,7 @@ def test_check_stdin_bytes(self, client, policy, logger): mock.call.exec_command('{val}\n'.format(val=print_stdin)) )) - def test_check_stdin_bytearray(self, client, policy, logger): + def test_013_check_stdin_bytearray(self, client, policy, logger): stdin_val = bytearray(b'this is a line') stdin = mock.Mock(name='stdin') @@ -586,7 +586,7 @@ def test_check_stdin_bytearray(self, client, policy, logger): mock.call.exec_command('{val}\n'.format(val=print_stdin)) )) - def test_check_stdin_closed(self, client, policy, logger): + def test_014_check_stdin_closed(self, client, policy, logger): stdin_val = 'this is a line' stdin = mock.Mock(name='stdin') @@ -631,6 +631,76 @@ def test_check_stdin_closed(self, client, policy, logger): mock.call.exec_command('{val}\n'.format(val=print_stdin)) )) + def test_015_keepalive(self, client, policy, logger): + chan = mock.Mock() + open_session = mock.Mock(return_value=chan) + transport = mock.Mock() + transport.attach_mock(open_session, 'open_session') + get_transport = mock.Mock(return_value=transport) + _ssh = mock.Mock() + _ssh.attach_mock(get_transport, 'get_transport') + client.return_value = _ssh + + ssh = self.get_ssh() + + with ssh: + pass + + _ssh.close.assert_not_called() + + def test_016_no_keepalive(self, client, policy, logger): + chan = mock.Mock() + open_session = mock.Mock(return_value=chan) + transport = mock.Mock() + transport.attach_mock(open_session, 'open_session') + get_transport = mock.Mock(return_value=transport) + _ssh = mock.Mock() + _ssh.attach_mock(get_transport, 'get_transport') + client.return_value = _ssh + + ssh = self.get_ssh() + ssh.keepalive_mode = False + + with ssh: + pass + + _ssh.close.assert_called_once() + + def test_017_keepalive_enforced(self, client, policy, logger): + chan = mock.Mock() + open_session = mock.Mock(return_value=chan) + transport = mock.Mock() + transport.attach_mock(open_session, 'open_session') + get_transport = mock.Mock(return_value=transport) + _ssh = mock.Mock() + _ssh.attach_mock(get_transport, 'get_transport') + client.return_value = _ssh + + ssh = self.get_ssh() + ssh.keepalive_mode = False + + with ssh.keepalive(): + pass + + _ssh.close.assert_not_called() + + def test_018_no_keepalive_enforced(self, client, policy, logger): + chan = mock.Mock() + open_session = mock.Mock(return_value=chan) + transport = mock.Mock() + transport.attach_mock(open_session, 'open_session') + get_transport = mock.Mock(return_value=transport) + _ssh = mock.Mock() + _ssh.attach_mock(get_transport, 'get_transport') + client.return_value = _ssh + + ssh = self.get_ssh() + + with ssh.keepalive(enforce=False): + pass + + _ssh.close.assert_called_once() + @staticmethod def get_patched_execute_async_retval( ec=0, @@ -680,7 +750,7 @@ def get_patched_execute_async_retval( return chan, '', exp_result, stderr, stdout @mock.patch('exec_helpers.ssh_client.SSHClient.execute_async') - def test_execute( + def test_019_execute( self, execute_async, client, policy, logger @@ -727,7 +797,7 @@ def test_execute( ) @mock.patch('exec_helpers.ssh_client.SSHClient.execute_async') - def test_execute_verbose( + def test_020_execute_verbose( self, execute_async, client, policy, logger): @@ -772,7 +842,7 @@ def test_execute_verbose( ) @mock.patch('exec_helpers.ssh_client.SSHClient.execute_async') - def test_execute_no_stdout( + def test_021_execute_no_stdout( self, execute_async, client, policy, logger @@ -816,7 +886,7 @@ def test_execute_no_stdout( ) @mock.patch('exec_helpers.ssh_client.SSHClient.execute_async') - def test_execute_no_stderr( + def test_022_execute_no_stderr( self, execute_async, client, policy, logger @@ -860,7 +930,7 @@ def test_execute_no_stderr( ) @mock.patch('exec_helpers.ssh_client.SSHClient.execute_async') - def test_execute_no_stdout_stderr( + def test_023_execute_no_stdout_stderr( self, execute_async, client, policy, logger @@ -907,7 +977,7 @@ def test_execute_no_stdout_stderr( @mock.patch('time.sleep', autospec=True) @mock.patch('exec_helpers.ssh_client.SSHClient.execute_async') - def test_execute_timeout( + def test_024_execute_timeout( self, execute_async, sleep, client, policy, logger): @@ -941,7 +1011,7 @@ def test_execute_timeout( @mock.patch('time.sleep', autospec=True) @mock.patch('exec_helpers.ssh_client.SSHClient.execute_async') - def test_execute_timeout_fail( + def test_025_execute_timeout_fail( self, execute_async, sleep, client, policy, logger): @@ -966,7 +1036,7 @@ def test_execute_timeout_fail( chan.assert_has_calls((mock.call.status_event.is_set(), )) @mock.patch('exec_helpers.ssh_client.SSHClient.execute_async') - def test_execute_mask_command( + def test_026_execute_mask_command( self, execute_async, client, policy, logger @@ -1019,7 +1089,7 @@ def test_execute_mask_command( ) @mock.patch('exec_helpers.ssh_client.SSHClient.execute_async') - def test_execute_together(self, execute_async, client, policy, logger): + def test_027_execute_together(self, execute_async, client, policy, logger): ( chan, _stdin, _, stderr, stdout ) = self.get_patched_execute_async_retval() @@ -1070,7 +1140,7 @@ def test_execute_together(self, execute_async, client, policy, logger): remotes=remotes, command=command, expected=[1]) @mock.patch('exec_helpers.ssh_client.SSHClient.execute_async') - def test_execute_together_exceptions( + def test_028_execute_together_exceptions( self, execute_async, # type: mock.Mock client, @@ -1108,7 +1178,7 @@ def test_execute_together_exceptions( self.assertIsInstance(exception, RuntimeError) @mock.patch('exec_helpers.ssh_client.SSHClient.execute') - def test_check_call(self, execute, client, policy, logger): + def test_029_check_call(self, execute, client, policy, logger): exit_code = 0 return_value = exec_result.ExecResult( cmd=command, @@ -1147,7 +1217,7 @@ def test_check_call(self, execute, client, policy, logger): execute.assert_called_once_with(command, verbose, None) @mock.patch('exec_helpers.ssh_client.SSHClient.execute') - def test_check_call_expected(self, execute, client, policy, logger): + def test_030_check_call_expected(self, execute, client, policy, logger): exit_code = 0 return_value = exec_result.ExecResult( cmd=command, @@ -1185,7 +1255,7 @@ def test_check_call_expected(self, execute, client, policy, logger): execute.assert_called_once_with(command, verbose, None) @mock.patch('exec_helpers.ssh_client.SSHClient.check_call') - def test_check_stderr(self, check_call, client, policy, logger): + def test_031_check_stderr(self, check_call, client, policy, logger): return_value = exec_result.ExecResult( cmd=command, stdout=stdout_list,