Skip to content

Commit d812ff3

Browse files
authoredAug 20, 2022
Perf changes (#357)
* Performance improvements for reading output in all clients. * Output reading for all clients has been changed to be less prone to race conditions. * Parallel clients now read a common private key only once, reusing it for all clients it applies to, to improve performance. * Updated changelog. * Added test case for joining on parallel clients without ever running run_command. Updated join so that it does not raise exception in that case.
1 parent 1b44e9a commit d812ff3

File tree

10 files changed

+198
-138
lines changed

10 files changed

+198
-138
lines changed
 

‎Changelog.rst

+20
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,26 @@
11
Change Log
22
============
33

4+
2.12.0
5+
+++++++
6+
7+
Changes
8+
--------
9+
10+
* Added ``alias`` optional parameter to ``SSHClient`` and ``HostConfig`` for passing through from parallel clients.
11+
Used to set an SSH host name alias, for cases where the real host name is the same and there is a need to
12+
differentiate output from otherwise identical host names - #355. Thank you @simonfelding.
13+
* Parallel clients now read a common private key only once, reusing it for all clients it applies to,
14+
to improve performance.
15+
* Performance improvements for all clients when reading output.
16+
* Output reading for all clients has been changed to be less prone to race conditions.
17+
18+
Fixes
19+
------
20+
21+
* Calling ``ParallelSSHClient.join`` without ever running ``run_command`` would raise exception. Is now a no-op.
22+
23+
424
2.11.1
525
+++++++
626

‎pssh/clients/base/parallel.py

+36-25
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from gevent import joinall, spawn, Timeout as GTimeout
2424
from gevent.hub import Hub
2525

26-
from ..common import _validate_pkey_path
26+
from ..common import _validate_pkey_path, _validate_pkey
2727
from ...config import HostConfig
2828
from ...constants import DEFAULT_RETRIES, RETRY_DELAY
2929
from ...exceptions import HostArgumentError, Timeout, ShellError, HostConfigError
@@ -39,7 +39,7 @@ class BaseParallelSSHClient(object):
3939
def __init__(self, hosts, user=None, password=None, port=None, pkey=None,
4040
allow_agent=True,
4141
num_retries=DEFAULT_RETRIES,
42-
timeout=120, pool_size=10,
42+
timeout=120, pool_size=100,
4343
host_config=None, retry_delay=RETRY_DELAY,
4444
identity_auth=True,
4545
ipv6_only=False,
@@ -64,7 +64,8 @@ def __init__(self, hosts, user=None, password=None, port=None, pkey=None,
6464
self.user = user
6565
self.password = password
6666
self.port = port
67-
self.pkey = pkey
67+
self.pkey = _validate_pkey(pkey)
68+
self.__pkey_data = self._load_pkey_data(pkey) if pkey is not None else None
6869
self.num_retries = num_retries
6970
self.timeout = timeout
7071
self._host_clients = {}
@@ -113,9 +114,26 @@ def hosts(self, _hosts):
113114
self._host_clients.pop((i, host), None)
114115
self._hosts = _hosts
115116

117+
def __del__(self):
118+
self.disconnect()
119+
120+
def disconnect(self):
121+
if not hasattr(self, '_host_clients'):
122+
return
123+
for s_client in self._host_clients.values():
124+
try:
125+
s_client.disconnect()
126+
except Exception as ex:
127+
logger.debug("Client disconnect failed with %s", ex)
128+
pass
129+
del s_client
130+
116131
def _check_host_config(self):
117132
if self.host_config is None:
118133
return
134+
if not isinstance(self.host_config, list):
135+
raise HostConfigError("Host configuration of type %s is invalid - valid types are List[HostConfig]",
136+
type(self.host_config))
119137
host_len = len(self.hosts)
120138
if host_len != len(self.host_config):
121139
raise ValueError(
@@ -231,7 +249,7 @@ def _get_output_from_cmds(self, cmds, raise_error=False):
231249

232250
def _get_output_from_greenlet(self, cmd_i, cmd, raise_error=False):
233251
host = self.hosts[cmd_i]
234-
alias = self._get_host_config(cmd_i, host).alias
252+
alias = self._get_host_config(cmd_i).alias
235253
try:
236254
host_out = cmd.get()
237255
return host_out
@@ -256,7 +274,7 @@ def get_last_output(self, cmds=None):
256274
return self._get_output_from_cmds(
257275
cmds, raise_error=False)
258276

259-
def _get_host_config(self, host_i, host):
277+
def _get_host_config(self, host_i):
260278
if self.host_config is None:
261279
config = HostConfig(
262280
user=self.user, port=self.port, password=self.password, private_key=self.pkey,
@@ -275,17 +293,13 @@ def _get_host_config(self, host_i, host):
275293
alias=None,
276294
)
277295
return config
278-
elif not isinstance(self.host_config, list):
279-
raise HostConfigError("Host configuration of type %s is invalid - valid types are list[HostConfig]",
280-
type(self.host_config))
281296
config = self.host_config[host_i]
282297
return config
283298

284299
def _run_command(self, host_i, host, command, sudo=False, user=None,
285300
shell=None, use_pty=False,
286301
encoding='utf-8', read_timeout=None):
287302
"""Make SSHClient if needed, run command on host"""
288-
logger.debug("_run_command with read timeout %s", read_timeout)
289303
try:
290304
_client = self._get_ssh_client(host_i, host)
291305
host_out = _client.run_command(
@@ -311,13 +325,13 @@ def connect_auth(self):
311325
:returns: list of greenlets to ``joinall`` with.
312326
:rtype: list(:py:mod:`gevent.greenlet.Greenlet`)
313327
"""
314-
cmds = [spawn(self._get_ssh_client, i, host) for i, host in enumerate(self.hosts)]
328+
cmds = [self.pool.spawn(self._get_ssh_client, i, host) for i, host in enumerate(self.hosts)]
315329
return cmds
316330

317331
def _consume_output(self, stdout, stderr):
318-
for line in stdout:
332+
for _ in stdout:
319333
pass
320-
for line in stderr:
334+
for _ in stderr:
321335
pass
322336

323337
def join(self, output=None, consume_output=False, timeout=None):
@@ -346,6 +360,9 @@ def join(self, output=None, consume_output=False, timeout=None):
346360
:rtype: ``None``"""
347361
if output is None:
348362
output = self.get_last_output()
363+
if output is None:
364+
logger.info("No last output to join on - run_command has never been run.")
365+
return
349366
elif not isinstance(output, list):
350367
raise ValueError("Unexpected output object type")
351368
cmds = [self.pool.spawn(self._join, host_out, timeout=timeout,
@@ -544,32 +561,26 @@ def _copy_remote_file(self, host_i, host, remote_file, local_file, recurse,
544561
return client.copy_remote_file(
545562
remote_file, local_file, recurse=recurse, **kwargs)
546563

547-
def _handle_greenlet_exc(self, func, host, *args, **kwargs):
548-
try:
549-
return func(*args, **kwargs)
550-
except Exception as ex:
551-
raise ex
552-
553564
def _get_ssh_client(self, host_i, host):
554565
logger.debug("Make client request for host %s, (host_i, host) in clients: %s",
555566
host, (host_i, host) in self._host_clients)
556567
_client = self._host_clients.get((host_i, host))
557568
if _client is not None:
558569
return _client
559-
cfg = self._get_host_config(host_i, host)
570+
cfg = self._get_host_config(host_i)
560571
_pkey = self.pkey if cfg.private_key is None else cfg.private_key
561572
_pkey_data = self._load_pkey_data(_pkey)
562573
_client = self._make_ssh_client(host, cfg, _pkey_data)
563574
self._host_clients[(host_i, host)] = _client
564575
return _client
565576

566577
def _load_pkey_data(self, _pkey):
567-
if isinstance(_pkey, str):
568-
_validate_pkey_path(_pkey)
569-
with open(_pkey, 'rb') as fh:
570-
_pkey_data = fh.read()
571-
return _pkey_data
572-
return _pkey
578+
if not isinstance(_pkey, str):
579+
return _pkey
580+
_pkey = _validate_pkey_path(_pkey)
581+
with open(_pkey, 'rb') as fh:
582+
_pkey_data = fh.read()
583+
return _pkey_data
573584

574585
def _make_ssh_client(self, host, cfg, _pkey_data):
575586
raise NotImplementedError

‎pssh/clients/base/single.py

+17-12
Original file line numberDiff line numberDiff line change
@@ -23,19 +23,17 @@
2323
from gevent import sleep, socket, Timeout as GTimeout
2424
from gevent.hub import Hub
2525
from gevent.select import poll, POLLIN, POLLOUT
26-
27-
from ssh2.utils import find_eol
2826
from ssh2.exceptions import AgentConnectionError, AgentListIdentitiesError, \
2927
AgentAuthenticationError, AgentGetIdentityError
28+
from ssh2.utils import find_eol
3029

3130
from ..common import _validate_pkey
32-
from ...constants import DEFAULT_RETRIES, RETRY_DELAY
3331
from ..reader import ConcurrentRWBuffer
32+
from ...constants import DEFAULT_RETRIES, RETRY_DELAY
3433
from ...exceptions import UnknownHostError, AuthenticationError, \
3534
ConnectionError, Timeout, NoIPv6AddressFoundError
3635
from ...output import HostOutput, HostOutputBuffers, BufferData
3736

38-
3937
Hub.NOT_ERROR = (Exception,)
4038
host_logger = logging.getLogger('pssh.host_logger')
4139
logger = logging.getLogger(__name__)
@@ -287,15 +285,15 @@ def _connect(self, host, port, retries=1):
287285
raise unknown_ex from ex
288286
for i, (family, _type, proto, _, sock_addr) in enumerate(addr_info):
289287
try:
290-
return self._connect_socket(family, _type, proto, sock_addr, host, port, retries)
288+
return self._connect_socket(family, _type, sock_addr, host, port, retries)
291289
except ConnectionRefusedError as ex:
292290
if i+1 == len(addr_info):
293291
logger.error("No available addresses from %s", [addr[4] for addr in addr_info])
294292
ex.args += (host, port)
295293
raise
296294
continue
297295

298-
def _connect_socket(self, family, _type, proto, sock_addr, host, port, retries):
296+
def _connect_socket(self, family, _type, sock_addr, host, port, retries):
299297
self.sock = socket.socket(family, _type)
300298
if self.timeout:
301299
self.sock.settimeout(self.timeout)
@@ -428,6 +426,8 @@ def read_stderr(self, stderr_buffer, timeout=None):
428426
429427
:param stderr_buffer: Buffer to read from.
430428
:type stderr_buffer: :py:class:`pssh.clients.reader.ConcurrentRWBuffer`
429+
:param timeout: Timeout in seconds - defaults to no timeout.
430+
:type timeout: int or float
431431
:rtype: generator
432432
"""
433433
logger.debug("Reading from stderr buffer, timeout=%s", timeout)
@@ -439,6 +439,8 @@ def read_output(self, stdout_buffer, timeout=None):
439439
440440
:param stdout_buffer: Buffer to read from.
441441
:type stdout_buffer: :py:class:`pssh.clients.reader.ConcurrentRWBuffer`
442+
:param timeout: Timeout in seconds - defaults to no timeout.
443+
:type timeout: int or float
442444
:rtype: generator
443445
"""
444446
logger.debug("Reading from stdout buffer, timeout=%s", timeout)
@@ -492,14 +494,16 @@ def read_output_buffer(self, output_buffer, prefix=None,
492494
encoding='utf-8'):
493495
"""Read from output buffers and log to ``host_logger``.
494496
495-
:param output_buffer: Iterator containing buffer
497+
:param output_buffer: Iterator containing buffer.
496498
:type output_buffer: iterator
497-
:param prefix: String to prefix log output to ``host_logger`` with
499+
:param prefix: String to prefix log output to ``host_logger`` with.
498500
:type prefix: str
499-
:param callback: Function to call back once buffer is depleted:
501+
:param callback: Function to call back once buffer is depleted.
500502
:type callback: function
501-
:param callback_args: Arguments for call back function
503+
:param callback_args: Arguments for call back function.
502504
:type callback_args: tuple
505+
:param encoding: Encoding for output.
506+
:type encoding: str
503507
"""
504508
prefix = '' if prefix is None else prefix
505509
for line in output_buffer:
@@ -553,7 +557,7 @@ def run_command(self, command, sudo=False, user=None,
553557
host_out = self._make_host_output(channel, encoding, _timeout)
554558
return host_out
555559

556-
def _eagain_write_errcode(self, write_func, data, eagain, timeout=None):
560+
def _eagain_write_errcode(self, write_func, data, eagain):
557561
data_len = len(data)
558562
total_written = 0
559563
while total_written < data_len:
@@ -570,9 +574,10 @@ def _eagain_errcode(self, func, eagain, *args, **kwargs):
570574
while ret == eagain:
571575
self.poll()
572576
ret = func(*args, **kwargs)
577+
sleep()
573578
return ret
574579

575-
def _eagain_write(self, write_func, data, timeout=None):
580+
def _eagain_write(self, write_func, data):
576581
raise NotImplementedError
577582

578583
def _eagain(self, func, *args, **kwargs):

‎pssh/clients/native/parallel.py

+10-21
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,6 @@ def __init__(self, hosts, user=None, password=None, port=22, pkey=None,
127127
identity_auth=identity_auth,
128128
ipv6_only=ipv6_only,
129129
)
130-
self.pkey = _validate_pkey(pkey)
131130
self.proxy_host = proxy_host
132131
self.proxy_port = proxy_port
133132
self.proxy_pkey = _validate_pkey(proxy_pkey)
@@ -216,17 +215,6 @@ def run_command(self, command, sudo=False, user=None, stop_on_errors=True,
216215
read_timeout=read_timeout,
217216
)
218217

219-
def __del__(self):
220-
if not hasattr(self, '_host_clients'):
221-
return
222-
for s_client in self._host_clients.values():
223-
try:
224-
s_client.disconnect()
225-
except Exception as ex:
226-
logger.debug("Client disconnect failed with %s", ex)
227-
pass
228-
del s_client
229-
230218
def _make_ssh_client(self, host, cfg, _pkey_data):
231219
_client = SSHClient(
232220
host, user=cfg.user or self.user, password=cfg.password or self.password, port=cfg.port or self.port,
@@ -371,16 +359,12 @@ def copy_remote_file(self, remote_file, local_file, recurse=False,
371359
encoding=encoding)
372360

373361
def _scp_send(self, host_i, host, local_file, remote_file, recurse=False):
374-
self._get_ssh_client(host_i, host)
375-
return self._handle_greenlet_exc(
376-
self._host_clients[(host_i, host)].scp_send, host,
377-
local_file, remote_file, recurse=recurse)
362+
_client = self._get_ssh_client(host_i, host)
363+
return _client.scp_send(local_file, remote_file, recurse=recurse)
378364

379365
def _scp_recv(self, host_i, host, remote_file, local_file, recurse=False):
380-
self._get_ssh_client(host_i, host)
381-
return self._handle_greenlet_exc(
382-
self._host_clients[(host_i, host)].scp_recv, host,
383-
remote_file, local_file, recurse=recurse)
366+
_client = self._get_ssh_client(host_i, host)
367+
return _client.scp_recv(remote_file, local_file, recurse=recurse)
384368

385369
def scp_send(self, local_file, remote_file, recurse=False, copy_args=None):
386370
"""Copy local file to remote file in parallel via SCP.
@@ -405,6 +389,11 @@ def scp_send(self, local_file, remote_file, recurse=False, copy_args=None):
405389
:type local_file: str
406390
:param remote_file: Remote filepath on remote host to copy file to
407391
:type remote_file: str
392+
:param copy_args: (Optional) format local_file and remote_file strings
393+
with per-host arguments in ``copy_args``. ``copy_args`` length must
394+
equal length of host list -
395+
:py:class:`pssh.exceptions.HostArgumentError` is raised otherwise
396+
:type copy_args: tuple or list
408397
:param recurse: Whether or not to descend into directories recursively.
409398
:type recurse: bool
410399
@@ -416,7 +405,7 @@ def scp_send(self, local_file, remote_file, recurse=False, copy_args=None):
416405
"""
417406
copy_args = [{'local_file': local_file,
418407
'remote_file': remote_file}
419-
for i, host in enumerate(self.hosts)] \
408+
for _ in self.hosts] \
420409
if copy_args is None else copy_args
421410
local_file = "%(local_file)s"
422411
remote_file = "%(remote_file)s"

0 commit comments

Comments
 (0)
Failed to load comments.