Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 6 additions & 20 deletions exec_helpers/_ssh_client_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,9 +683,7 @@ def _exec_command(

.. versionchanged:: 1.2.0 log_mask_re regex rule for masking cmd
"""
def poll_streams(
result, # type: exec_result.ExecResult
):
def poll_streams():
"""Poll FIFO buffers if data available."""
if stdout and interface.recv_ready():
result.read_stdout(
Expand All @@ -701,22 +699,15 @@ def poll_streams(
)

@threaded.threadpooled
def poll_pipes(
result, # type: exec_result.ExecResult
stop, # type: threading.Event
):
def poll_pipes(stop, ): # type: (threading.Event) -> None
"""Polling task for FIFO buffers.

:type stdout: paramiko.channel.ChannelFile
:type stderr: paramiko.channel.ChannelFile
:type result: ExecResult
:type stop: Event
:type channel: paramiko.channel.Channel
"""
while not stop.is_set():
time.sleep(0.1)
if stdout or stderr:
poll_streams(result=result)
poll_streams()

if interface.status_event.is_set():
result.read_stdout(
Expand Down Expand Up @@ -744,10 +735,7 @@ def poll_pipes(
stop_event = threading.Event()

# pylint: disable=assignment-from-no-return
future = poll_pipes(
result=result,
stop=stop_event,
) # type: concurrent.futures.Future
future = poll_pipes(stop=stop_event) # type: concurrent.futures.Future
# pylint: enable=assignment-from-no-return

concurrent.futures.wait([future], timeout)
Expand Down Expand Up @@ -880,9 +868,7 @@ def execute_together(
.. versionchanged:: 1.2.0 log_mask_re regex rule for masking cmd
"""
@threaded.threadpooled
def get_result(
remote # type: SSHClientBase
): # type: (...) -> exec_result.ExecResult
def get_result(): # type: () -> exec_result.ExecResult
"""Get result from remote call."""
(
chan,
Expand Down Expand Up @@ -921,7 +907,7 @@ def get_result(
raised_exceptions = {}

for remote in set(remotes): # Use distinct remotes
futures[remote] = get_result(remote)
futures[remote] = get_result()

(
_,
Expand Down
24 changes: 11 additions & 13 deletions exec_helpers/ssh_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import copy
import io # noqa # pylint: disable=unused-import
import logging
import typing
import typing # noqa # pylint: disable=unused-import

import paramiko

Expand All @@ -33,11 +33,6 @@
logging.getLogger('paramiko').setLevel(logging.WARNING)
logging.getLogger('iso8601').setLevel(logging.WARNING)

_type_ConnectSSH = typing.Union[
paramiko.client.SSHClient, paramiko.transport.Transport
]
_type_RSAKeys = typing.Iterable[paramiko.RSAKey]


class SSHAuth(object):
"""SSH Authorization object."""
Expand Down Expand Up @@ -141,18 +136,15 @@ def enter_password(self, tgt): # type: (io.StringIO) -> None

def connect(
self,
client, # type: _type_ConnectSSH
client, # type: typing.Union[paramiko.SSHClient, paramiko.Transport]
hostname=None, # type: typing.Optional[str]
port=22, # type: int
log=True, # type: bool
): # type: (...) -> None
"""Connect SSH client object using credentials.

:param client: SSH Client (low level)
:type client: typing.Union[
paramiko.client.SSHClient,
paramiko.transport.Transport,
]
:type client: typing.Union[paramiko.SSHClient, paramiko.Transport]
:param hostname: remote hostname
:type hostname: str
:param port: remote ssh port
Expand All @@ -164,13 +156,19 @@ def connect(
kwargs = {
'username': self.username,
'password': self.__password,
'key_filename': self.key_filename,
'passphrase': self.__passphrase,
} # type: typing.Dict[str, typing.Any]
if hostname is not None:
kwargs['hostname'] = hostname
kwargs['port'] = port

if isinstance(client, paramiko.client.SSHClient): # pragma: no cover
# paramiko.transport.Transport still do not allow passphrase and key filename

if self.key_filename is not None:
kwargs['key_filename'] = self.key_filename
if self.__passphrase is not None:
kwargs['passphrase'] = self.__passphrase

keys = [self.__key]
keys.extend([k for k in self.__keys if k != self.__key])

Expand Down
17 changes: 4 additions & 13 deletions exec_helpers/subprocess_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,9 +178,7 @@ def _exec_command(

.. versionadded:: 1.2.0
"""
def poll_streams(
result, # type: exec_result.ExecResult
):
def poll_streams():
"""Poll streams to the result object."""
if _win: # pragma: no cover
# select.select is not supported on windows
Expand All @@ -206,19 +204,15 @@ def poll_streams(
)

@threaded.threadpooled()
def poll_pipes(
result, # type: exec_result.ExecResult
stop, # type: threading.Event
):
def poll_pipes(stop, ): # type: (threading.Event) -> None
"""Polling task for FIFO buffers.

:type result: ExecResult
:type stop: Event
"""
while not stop.is_set():
time.sleep(0.1)
if stdout or stderr:
poll_streams(result=result)
poll_streams()

interface.poll()

Expand Down Expand Up @@ -247,10 +241,7 @@ def poll_pipes(
stop_event = threading.Event()

# pylint: disable=assignment-from-no-return
future = poll_pipes(
result,
stop_event
) # type: concurrent.futures.Future
future = poll_pipes(stop_event) # type: concurrent.futures.Future
# pylint: enable=assignment-from-no-return
# wait for process close

Expand Down
28 changes: 12 additions & 16 deletions test/test_ssh_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1314,7 +1314,10 @@ def prepare_execute_through_host(transp, client, exit_code):
return_value=intermediate_channel,
name='open_channel'
)
intermediate_transport = mock.Mock(name='intermediate_transport')
intermediate_transport = mock.Mock(
name='intermediate_transport',
spec='paramiko.transport.Transport'
)
intermediate_transport.attach_mock(open_channel, 'open_channel')
get_transport = mock.Mock(
return_value=intermediate_transport,
Expand Down Expand Up @@ -1356,8 +1359,7 @@ def prepare_execute_through_host(transp, client, exit_code):
open_channel, intermediate_channel
)

def test_execute_through_host_no_creds(
self, transp, client, policy, logger):
def test_01_execute_through_host_no_creds(self, transp, client, policy, logger):
target = '127.0.0.2'
exit_code = 0

Expand Down Expand Up @@ -1398,10 +1400,7 @@ def test_execute_through_host_no_creds(
transp.assert_called_once_with(intermediate_channel)
open_session.assert_called_once()
transport.assert_has_calls((
mock.call.connect(
username=username, password=password, pkey=None,
key_filename=None, passphrase=None,
),
mock.call.connect(username=username, password=password, pkey=None),
mock.call.open_session()
))
channel.assert_has_calls((
Expand All @@ -1414,8 +1413,7 @@ def test_execute_through_host_no_creds(
mock.call.close()
))

def test_execute_through_host_auth(
self, transp, client, policy, logger):
def test_02_execute_through_host_auth(self, transp, client, policy, logger):
_login = 'cirros'
_password = 'cubswin:)'

Expand All @@ -1442,7 +1440,9 @@ def test_execute_through_host_auth(
port=port,
auth=exec_helpers.SSHAuth(
username=username,
password=password
password=password,
key_filename='~/fake_key',
passphrase='fake_passphrase'
))

# noinspection PyTypeChecker
Expand All @@ -1455,10 +1455,7 @@ def test_execute_through_host_auth(
transp.assert_called_once_with(intermediate_channel)
open_session.assert_called_once()
transport.assert_has_calls((
mock.call.connect(
username=_login, password=_password, pkey=None,
key_filename=None, passphrase=None,
),
mock.call.connect(username=_login, password=_password, pkey=None),
mock.call.open_session()
))
channel.assert_has_calls((
Expand All @@ -1473,8 +1470,7 @@ def test_execute_through_host_auth(


@mock.patch('exec_helpers._ssh_client_base.logger', autospec=True)
@mock.patch(
'paramiko.AutoAddPolicy', autospec=True, return_value='AutoAddPolicy')
@mock.patch('paramiko.AutoAddPolicy', autospec=True, return_value='AutoAddPolicy')
@mock.patch('paramiko.SSHClient', autospec=True)
class TestSftp(unittest.TestCase):
def tearDown(self):
Expand Down
Loading