diff --git a/exec_helpers/_ssh_client_base.py b/exec_helpers/_ssh_client_base.py index 6fb90e6..50254f1 100644 --- a/exec_helpers/_ssh_client_base.py +++ b/exec_helpers/_ssh_client_base.py @@ -683,9 +683,7 @@ def _exec_command( .. versionchanged:: 1.2.0 log_mask_re regex rule for masking cmd """ - def poll_streams( - result, # type: exec_result.ExecResult - ): + def poll_streams(): """Poll FIFO buffers if data available.""" if stdout and interface.recv_ready(): result.read_stdout( @@ -701,22 +699,15 @@ def poll_streams( ) @threaded.threadpooled - def poll_pipes( - result, # type: exec_result.ExecResult - stop, # type: threading.Event - ): + def poll_pipes(stop, ): # type: (threading.Event) -> None """Polling task for FIFO buffers. - :type stdout: paramiko.channel.ChannelFile - :type stderr: paramiko.channel.ChannelFile - :type result: ExecResult :type stop: Event - :type channel: paramiko.channel.Channel """ while not stop.is_set(): time.sleep(0.1) if stdout or stderr: - poll_streams(result=result) + poll_streams() if interface.status_event.is_set(): result.read_stdout( @@ -744,10 +735,7 @@ def poll_pipes( stop_event = threading.Event() # pylint: disable=assignment-from-no-return - future = poll_pipes( - result=result, - stop=stop_event, - ) # type: concurrent.futures.Future + future = poll_pipes(stop=stop_event) # type: concurrent.futures.Future # pylint: enable=assignment-from-no-return concurrent.futures.wait([future], timeout) @@ -880,9 +868,7 @@ def execute_together( .. versionchanged:: 1.2.0 log_mask_re regex rule for masking cmd """ @threaded.threadpooled - def get_result( - remote # type: SSHClientBase - ): # type: (...) -> exec_result.ExecResult + def get_result(): # type: () -> exec_result.ExecResult """Get result from remote call.""" ( chan, @@ -921,7 +907,7 @@ def get_result( raised_exceptions = {} for remote in set(remotes): # Use distinct remotes - futures[remote] = get_result(remote) + futures[remote] = get_result() ( _, diff --git a/exec_helpers/ssh_auth.py b/exec_helpers/ssh_auth.py index 6d751eb..7ab9f98 100644 --- a/exec_helpers/ssh_auth.py +++ b/exec_helpers/ssh_auth.py @@ -23,7 +23,7 @@ import copy import io # noqa # pylint: disable=unused-import import logging -import typing +import typing # noqa # pylint: disable=unused-import import paramiko @@ -33,11 +33,6 @@ logging.getLogger('paramiko').setLevel(logging.WARNING) logging.getLogger('iso8601').setLevel(logging.WARNING) -_type_ConnectSSH = typing.Union[ - paramiko.client.SSHClient, paramiko.transport.Transport -] -_type_RSAKeys = typing.Iterable[paramiko.RSAKey] - class SSHAuth(object): """SSH Authorization object.""" @@ -141,7 +136,7 @@ def enter_password(self, tgt): # type: (io.StringIO) -> None def connect( self, - client, # type: _type_ConnectSSH + client, # type: typing.Union[paramiko.SSHClient, paramiko.Transport] hostname=None, # type: typing.Optional[str] port=22, # type: int log=True, # type: bool @@ -149,10 +144,7 @@ def connect( """Connect SSH client object using credentials. :param client: SSH Client (low level) - :type client: typing.Union[ - paramiko.client.SSHClient, - paramiko.transport.Transport, - ] + :type client: typing.Union[paramiko.SSHClient, paramiko.Transport] :param hostname: remote hostname :type hostname: str :param port: remote ssh port @@ -164,13 +156,19 @@ def connect( kwargs = { 'username': self.username, 'password': self.__password, - 'key_filename': self.key_filename, - 'passphrase': self.__passphrase, } # type: typing.Dict[str, typing.Any] if hostname is not None: kwargs['hostname'] = hostname kwargs['port'] = port + if isinstance(client, paramiko.client.SSHClient): # pragma: no cover + # paramiko.transport.Transport still do not allow passphrase and key filename + + if self.key_filename is not None: + kwargs['key_filename'] = self.key_filename + if self.__passphrase is not None: + kwargs['passphrase'] = self.__passphrase + keys = [self.__key] keys.extend([k for k in self.__keys if k != self.__key]) diff --git a/exec_helpers/subprocess_runner.py b/exec_helpers/subprocess_runner.py index 24fa873..38c364f 100644 --- a/exec_helpers/subprocess_runner.py +++ b/exec_helpers/subprocess_runner.py @@ -178,9 +178,7 @@ def _exec_command( .. versionadded:: 1.2.0 """ - def poll_streams( - result, # type: exec_result.ExecResult - ): + def poll_streams(): """Poll streams to the result object.""" if _win: # pragma: no cover # select.select is not supported on windows @@ -206,19 +204,15 @@ def poll_streams( ) @threaded.threadpooled() - def poll_pipes( - result, # type: exec_result.ExecResult - stop, # type: threading.Event - ): + def poll_pipes(stop, ): # type: (threading.Event) -> None """Polling task for FIFO buffers. - :type result: ExecResult :type stop: Event """ while not stop.is_set(): time.sleep(0.1) if stdout or stderr: - poll_streams(result=result) + poll_streams() interface.poll() @@ -247,10 +241,7 @@ def poll_pipes( stop_event = threading.Event() # pylint: disable=assignment-from-no-return - future = poll_pipes( - result, - stop_event - ) # type: concurrent.futures.Future + future = poll_pipes(stop_event) # type: concurrent.futures.Future # pylint: enable=assignment-from-no-return # wait for process close diff --git a/test/test_ssh_client.py b/test/test_ssh_client.py index 65064cb..9af883f 100644 --- a/test/test_ssh_client.py +++ b/test/test_ssh_client.py @@ -1314,7 +1314,10 @@ def prepare_execute_through_host(transp, client, exit_code): return_value=intermediate_channel, name='open_channel' ) - intermediate_transport = mock.Mock(name='intermediate_transport') + intermediate_transport = mock.Mock( + name='intermediate_transport', + spec='paramiko.transport.Transport' + ) intermediate_transport.attach_mock(open_channel, 'open_channel') get_transport = mock.Mock( return_value=intermediate_transport, @@ -1356,8 +1359,7 @@ def prepare_execute_through_host(transp, client, exit_code): open_channel, intermediate_channel ) - def test_execute_through_host_no_creds( - self, transp, client, policy, logger): + def test_01_execute_through_host_no_creds(self, transp, client, policy, logger): target = '127.0.0.2' exit_code = 0 @@ -1398,10 +1400,7 @@ def test_execute_through_host_no_creds( transp.assert_called_once_with(intermediate_channel) open_session.assert_called_once() transport.assert_has_calls(( - mock.call.connect( - username=username, password=password, pkey=None, - key_filename=None, passphrase=None, - ), + mock.call.connect(username=username, password=password, pkey=None), mock.call.open_session() )) channel.assert_has_calls(( @@ -1414,8 +1413,7 @@ def test_execute_through_host_no_creds( mock.call.close() )) - def test_execute_through_host_auth( - self, transp, client, policy, logger): + def test_02_execute_through_host_auth(self, transp, client, policy, logger): _login = 'cirros' _password = 'cubswin:)' @@ -1442,7 +1440,9 @@ def test_execute_through_host_auth( port=port, auth=exec_helpers.SSHAuth( username=username, - password=password + password=password, + key_filename='~/fake_key', + passphrase='fake_passphrase' )) # noinspection PyTypeChecker @@ -1455,10 +1455,7 @@ def test_execute_through_host_auth( transp.assert_called_once_with(intermediate_channel) open_session.assert_called_once() transport.assert_has_calls(( - mock.call.connect( - username=_login, password=_password, pkey=None, - key_filename=None, passphrase=None, - ), + mock.call.connect(username=_login, password=_password, pkey=None), mock.call.open_session() )) channel.assert_has_calls(( @@ -1473,8 +1470,7 @@ def test_execute_through_host_auth( @mock.patch('exec_helpers._ssh_client_base.logger', autospec=True) -@mock.patch( - 'paramiko.AutoAddPolicy', autospec=True, return_value='AutoAddPolicy') +@mock.patch('paramiko.AutoAddPolicy', autospec=True, return_value='AutoAddPolicy') @mock.patch('paramiko.SSHClient', autospec=True) class TestSftp(unittest.TestCase): def tearDown(self): diff --git a/test/test_ssh_client_init.py b/test/test_ssh_client_init.py index feffdc8..0c77e48 100644 --- a/test/test_ssh_client_init.py +++ b/test/test_ssh_client_init.py @@ -64,8 +64,7 @@ def __iter__(self): # noinspection PyTypeChecker @mock.patch('exec_helpers.ssh_auth.logger', autospec=True) -@mock.patch( - 'paramiko.AutoAddPolicy', autospec=True, return_value='AutoAddPolicy') +@mock.patch('paramiko.AutoAddPolicy', autospec=True, return_value='AutoAddPolicy') @mock.patch('paramiko.SSHClient', autospec=True) class TestSSHClientInit(unittest.TestCase): def tearDown(self): @@ -121,33 +120,43 @@ def init_checks( if auth is None: if private_keys is None or len(private_keys) == 0: pkey = None + + kwargs = dict( + hostname=host, password=password, + pkey=pkey, + port=port, username=username, + ) + if key_filename: + kwargs['key_filename'] = key_filename + if passphrase: + kwargs['passphrase'] = passphrase + expected_calls = [ _ssh, _ssh.set_missing_host_key_policy('AutoAddPolicy'), - _ssh.connect( - hostname=host, password=password, - pkey=pkey, - port=port, username=username, - key_filename=key_filename, passphrase=passphrase - ), + _ssh.connect(**kwargs), ] else: pkey = private_keys[0] + + kwargs = dict( + hostname=host, password=password, + pkey=None, + port=port, username=username, + ) + if key_filename: + kwargs['key_filename'] = key_filename + if passphrase: + kwargs['passphrase'] = passphrase + + kwargs1 = {key: kwargs[key] for key in kwargs} + kwargs1['pkey'] = pkey + expected_calls = [ _ssh, _ssh.set_missing_host_key_policy('AutoAddPolicy'), - _ssh.connect( - hostname=host, password=password, - pkey=None, - port=port, username=username, - key_filename=key_filename, passphrase=passphrase - ), - _ssh.connect( - hostname=host, password=password, - pkey=pkey, - port=port, username=username, - key_filename=key_filename, passphrase=passphrase - ), + _ssh.connect(**kwargs), + _ssh.connect(**kwargs1), ] self.assertIn(expected_calls, client.mock_calls) @@ -180,13 +189,13 @@ def init_checks( ) ) - def test_init_host(self, client, policy, logger): + def test_001_init_host(self, client, policy, logger): """Test with host only set""" self.init_checks( client, policy, logger, host=host) - def test_init_alternate_port(self, client, policy, logger): + def test_002_init_alternate_port(self, client, policy, logger): """Test with alternate port""" self.init_checks( client, policy, logger, @@ -194,7 +203,7 @@ def test_init_alternate_port(self, client, policy, logger): port=2222 ) - def test_init_username(self, client, policy, logger): + def test_003_init_username(self, client, policy, logger): """Test with username only set from creds""" self.init_checks( client, policy, logger, @@ -202,7 +211,7 @@ def test_init_username(self, client, policy, logger): username=username ) - def test_init_username_password(self, client, policy, logger): + def test_004_init_username_password(self, client, policy, logger): """Test with username and password set from creds""" self.init_checks( client, policy, logger, @@ -211,7 +220,7 @@ def test_init_username_password(self, client, policy, logger): password=password ) - def test_init_username_password_empty_keys(self, client, policy, logger): + def test_005_init_username_password_empty_keys(self, client, policy, logger): """Test with username, password and empty keys set from creds""" self.init_checks( client, policy, logger, @@ -221,7 +230,7 @@ def test_init_username_password_empty_keys(self, client, policy, logger): private_keys=[] ) - def test_init_username_single_key(self, client, policy, logger): + def test_006_init_username_single_key(self, client, policy, logger): """Test with username and single key set from creds""" connect = mock.Mock( side_effect=[ @@ -238,7 +247,7 @@ def test_init_username_single_key(self, client, policy, logger): private_keys=gen_private_keys(1), ) - def test_init_username_password_single_key(self, client, policy, logger): + def test_007_init_username_password_single_key(self, client, policy, logger): """Test with username, password and single key set from creds""" connect = mock.Mock( side_effect=[ @@ -256,7 +265,7 @@ def test_init_username_password_single_key(self, client, policy, logger): private_keys=gen_private_keys(1) ) - def test_init_username_multiple_keys(self, client, policy, logger): + def test_008_init_username_multiple_keys(self, client, policy, logger): """Test with username and multiple keys set from creds""" connect = mock.Mock( side_effect=[ @@ -273,7 +282,7 @@ def test_init_username_multiple_keys(self, client, policy, logger): private_keys=gen_private_keys(2) ) - def test_init_username_password_multiple_keys( + def test_009_init_username_password_multiple_keys( self, client, policy, logger): """Test with username, password and multiple keys set from creds""" connect = mock.Mock( @@ -300,7 +309,7 @@ def test_init_username_password_multiple_keys( private_keys=gen_private_keys(2) ) - def test_init_auth(self, client, policy, logger): + def test_010_init_auth(self, client, policy, logger): self.init_checks( client, policy, logger, host=host, @@ -311,7 +320,7 @@ def test_init_auth(self, client, policy, logger): ) ) - def test_init_auth_break(self, client, policy, logger): + def test_011_init_auth_break(self, client, policy, logger): self.init_checks( client, policy, logger, host=host, @@ -325,7 +334,7 @@ def test_init_auth_break(self, client, policy, logger): ) ) - def test_init_context(self, client, policy, logger): + def test_012_init_context(self, client, policy, logger): with exec_helpers.SSHClient( host=host, auth=exec_helpers.SSHAuth() @@ -345,7 +354,7 @@ def test_init_context(self, client, policy, logger): self.assertEqual(ssh.hostname, host) self.assertEqual(ssh.port, port) - def test_init_clear_failed(self, client, policy, logger): + def test_013_init_clear_failed(self, client, policy, logger): """Test reconnect :type client: mock.Mock @@ -407,7 +416,7 @@ def test_init_clear_failed(self, client, policy, logger): mock.call.exception('Could not close sftp connection'), )) - def test_init_reconnect(self, client, policy, logger): + def test_014_init_reconnect(self, client, policy, logger): """Test reconnect :type client: mock.Mock @@ -447,8 +456,6 @@ def test_init_reconnect(self, client, policy, logger): pkey=None, port=22, username=None, - key_filename=None, - passphrase=None ), ] self.assertIn( @@ -469,7 +476,7 @@ def test_init_reconnect(self, client, policy, logger): self.assertEqual(ssh._ssh, client()) @mock.patch('time.sleep', autospec=True) - def test_init_password_required(self, sleep, client, policy, logger): + def test_015_init_password_required(self, sleep, client, policy, logger): connect = mock.Mock(side_effect=paramiko.PasswordRequiredException) _ssh = mock.Mock() _ssh.attach_mock(connect, 'connect') @@ -482,7 +489,7 @@ def test_init_password_required(self, sleep, client, policy, logger): )) @mock.patch('time.sleep', autospec=True) - def test_init_password_broken(self, sleep, client, policy, logger): + def test_016_init_password_broken(self, sleep, client, policy, logger): connect = mock.Mock(side_effect=paramiko.PasswordRequiredException) _ssh = mock.Mock() _ssh.attach_mock(connect, 'connect') @@ -500,7 +507,7 @@ def test_init_password_broken(self, sleep, client, policy, logger): )) @mock.patch('time.sleep', autospec=True) - def test_init_auth_impossible_password( + def test_017_init_auth_impossible_password( self, sleep, client, policy, logger): connect = mock.Mock(side_effect=paramiko.AuthenticationException) @@ -520,7 +527,7 @@ def test_init_auth_impossible_password( ) @mock.patch('time.sleep', autospec=True) - def test_init_auth_impossible_key(self, sleep, client, policy, logger): + def test_018_init_auth_impossible_key(self, sleep, client, policy, logger): connect = mock.Mock(side_effect=paramiko.AuthenticationException) _ssh = mock.Mock() @@ -540,7 +547,7 @@ def test_init_auth_impossible_key(self, sleep, client, policy, logger): ) * 3 ) - def test_init_auth_pass_no_key(self, client, policy, logger): + def test_019_init_auth_pass_no_key(self, client, policy, logger): connect = mock.Mock( side_effect=[ paramiko.AuthenticationException, @@ -584,7 +591,7 @@ def test_init_auth_pass_no_key(self, client, policy, logger): self.assertEqual(ssh._ssh, client()) @mock.patch('time.sleep', autospec=True) - def test_init_auth_brute_impossible(self, sleep, client, policy, logger): + def test_020_init_auth_brute_impossible(self, sleep, client, policy, logger): connect = mock.Mock(side_effect=paramiko.AuthenticationException) _ssh = mock.Mock() @@ -604,7 +611,7 @@ def test_init_auth_brute_impossible(self, sleep, client, policy, logger): ) * 3 ) - def test_init_no_sftp(self, client, policy, logger): + def test_021_init_no_sftp(self, client, policy, logger): open_sftp = mock.Mock(side_effect=paramiko.SSHException) _ssh = mock.Mock() @@ -633,7 +640,7 @@ def test_init_no_sftp(self, client, policy, logger): 'SFTP enable failed! SSH only is accessible.'), )) - def test_init_sftp_repair(self, client, policy, logger): + def test_022_init_sftp_repair(self, client, policy, logger): _sftp = mock.Mock() open_sftp = mock.Mock( side_effect=[ @@ -671,7 +678,7 @@ def test_init_sftp_repair(self, client, policy, logger): )) @mock.patch('exec_helpers.exec_result.ExecResult', autospec=True) - def test_init_memorize( + def test_023_init_memorize( self, Result, client, policy, logger): @@ -723,7 +730,7 @@ def test_init_memorize( 'CPython' != platform.python_implementation(), 'CPython only functionality: close connections depend on refcount' ) - def test_init_memorize_close_unused(self, warn, client, policy, logger): + def test_024_init_memorize_close_unused(self, warn, client, policy, logger): ssh0 = exec_helpers.SSHClient(host=host) del ssh0 # remove reference - now it's cached and unused client.reset_mock() @@ -743,7 +750,7 @@ def test_init_memorize_close_unused(self, warn, client, policy, logger): )) @mock.patch('exec_helpers.ssh_client.SSHClient.execute') - def test_init_memorize_reconnect(self, execute, client, policy, logger): + def test_025_init_memorize_reconnect(self, execute, client, policy, logger): execute.side_effect = paramiko.SSHException exec_helpers.SSHClient(host=host) client.reset_mock()