From 2631fab0b70c419af6cfa066b2a19b504720d2b1 Mon Sep 17 00:00:00 2001 From: Daniel Kertesz Date: Sun, 10 Mar 2019 22:26:08 +0000 Subject: [PATCH] support for cluster commands limited --- fox/cluster.py | 35 +++++++++++++++++++++++++---------- fox/connection.py | 4 +++- 2 files changed, 28 insertions(+), 11 deletions(-) diff --git a/fox/cluster.py b/fox/cluster.py index 2e7ead3..47cee6c 100644 --- a/fox/cluster.py +++ b/fox/cluster.py @@ -29,29 +29,44 @@ def __init__(self, *hosts): self.hosts = hosts self._connections = [_get_connection(host, use_cache=False) for host in self.hosts] - def run(self, command): - return run_in_loop(self._run(command)) + def run(self, command, limit=0): + return run_in_loop(self._run(command, limit)) - async def _run(self, command): + async def _run(self, command, limit=0): bar = tqdm.tqdm(total=len(self.hosts)) qbar = asyncio.Queue() + futures_done = [] + aws = set() - results = await asyncio.gather( - *[self._do(qbar, connection, command) for connection in self._connections], - _update_bar(bar, len(self.hosts), qbar), - return_exceptions=True, - ) - for connection, result in results[:-1]: + bar_task = asyncio.ensure_future(_update_bar(bar, len(self.hosts), qbar)) + + for connection in self._connections: + aws.add(asyncio.ensure_future(self._do(qbar, connection, command))) + if limit and len(aws) >= limit: + done, pending = await asyncio.wait(aws, return_when=asyncio.FIRST_COMPLETED) + aws = pending + futures_done.extend(done) + + if len(aws): + done, pending = await asyncio.wait(aws, return_when=asyncio.ALL_COMPLETED) + futures_done.extend(done) + + _ = await asyncio.wait({bar_task}) + + results = [future.result() for future in futures_done] + for connection, result in results: if isinstance(result, CommandResult): print(f"output from {connection.nickname}: {result.stdout}", end="") else: print(f"command failed on {connection.nickname}: {result}") + return results + async def _do(self, queue, connection, command): try: result = await connection._run(command, echo=False) except Exception as exc: - print(f"task on {connection} failed: {exc}") + print(f"Task on {connection} failed: {exc}") result = None await queue.put(1) diff --git a/fox/connection.py b/fox/connection.py index c2cfba0..beb59bd 100644 --- a/fox/connection.py +++ b/fox/connection.py @@ -363,7 +363,9 @@ def _get_connection(name=None, use_cache=True) -> Connection: # NOTE: we only cache connections created here, and maybe the tunnels. # maybe by default we should not re-use the tunnels, as the default behavior of SSH - c = Connection(ssh_options["hostname"], ssh_options["user"], ssh_options["port"], **args) + c = Connection( + ssh_options["hostname"], ssh_options["user"], ssh_options["port"], nickname=name, **args + ) if use_cache: _connections_cache[name] = c return c