Skip to content

Commit

Permalink
clustershell: instantiate progress bar earlier
Browse files Browse the repository at this point in the history
* Instead of adding logic regarding which progress bar instance to use
  in the event handler, move the decision upwards in the stack and let
  the event handler instances be called directly with a progress bar
  instance.

Change-Id: Ic35d86ceaa2b072d32d603903a619e535576faaf
  • Loading branch information
volans- committed Apr 27, 2021
1 parent 33a167c commit 5d56e11
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 15 deletions.
22 changes: 13 additions & 9 deletions cumin/tests/unit/transports/test_clustershell.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import pytest

from cumin import CuminError, nodeset
from cumin.transports import BaseWorker, clustershell, Command, State, Target, TqdmProgressBars, WorkerError
from cumin.transports import (BaseExecutionProgress, BaseWorker, clustershell, Command, State, Target,
TqdmProgressBars, WorkerError)
from cumin.transports.clustershell import Node, NullReporter, TqdmReporter


Expand Down Expand Up @@ -245,10 +246,12 @@ def setup_method(self, *args): # pylint: disable=arguments-differ
self.worker.nodes = self.target.hosts
self.handler = None
self.args = args
self.progress_bars = mock.MagicMock(spec_set=BaseExecutionProgress)

def test_close(self):
"""Calling close should raise NotImplementedError."""
self.handler = clustershell.BaseEventHandler(self.target, self.commands, TqdmReporter())
self.handler = clustershell.BaseEventHandler(self.target, self.commands, TqdmReporter(),
progress_bars=self.progress_bars)
with pytest.raises(NotImplementedError):
self.handler.close(self.worker)

Expand All @@ -272,7 +275,8 @@ class TestConcreteBaseEventHandler(TestBaseEventHandler):
def setup_method(self, _, tqdm): # pylint: disable=arguments-differ
"""Initialize default properties and instances."""
super().setup_method()
self.handler = ConcreteBaseEventHandler(self.target, self.commands, TqdmReporter())
self.handler = ConcreteBaseEventHandler(self.target, self.commands, TqdmReporter(),
progress_bars=self.progress_bars)
self.worker.eh = self.handler
assert not tqdm.write.called

Expand Down Expand Up @@ -314,7 +318,8 @@ def test_ev_read_many_hosts(self, tqdm):
def test_ev_read_single_host(self, tqdm):
"""Calling ev_read() should print the worker message if matching a single host."""
self.target = Target(nodeset('node1'))
self.handler = ConcreteBaseEventHandler(self.target, self.commands, TqdmReporter())
self.handler = ConcreteBaseEventHandler(self.target, self.commands, TqdmReporter(),
progress_bars=self.progress_bars)

output = b'node1 output'
self.worker.nodes = self.target.hosts
Expand All @@ -341,8 +346,7 @@ def setup_method(self, _, tqdm, logger): # pylint: disable=arguments-differ
"""Initialize default properties and instances."""
super().setup_method()
self.handler = clustershell.SyncEventHandler(self.target, self.commands, TqdmReporter(),
success_threshold=1)
self.handler.progress = mock.Mock(spec_set=TqdmProgressBars)
progress_bars=self.progress_bars, success_threshold=1)
self.worker.eh = self.handler
self.logger = logger
assert not tqdm.write.called
Expand Down Expand Up @@ -428,13 +432,13 @@ def test_close(self, tqdm): # pylint: disable=arguments-differ
class TestAsyncEventHandler(TestBaseEventHandler):
"""AsyncEventHandler test class."""

@mock.patch('cumin.transports.clustershell.TqdmProgressBars')
@mock.patch('cumin.transports.clustershell.logging')
@mock.patch('cumin.transports.clustershell.tqdm')
def setup_method(self, _, tqdm, logger, progress): # pylint: disable=arguments-differ,unused-argument
def setup_method(self, _, tqdm, logger): # pylint: disable=arguments-differ,unused-argument
"""Initialize default properties and instances."""
super().setup_method()
self.handler = clustershell.AsyncEventHandler(self.target, self.commands, TqdmReporter())
self.handler = clustershell.AsyncEventHandler(self.target, self.commands, TqdmReporter(),
progress_bars=self.progress_bars)
self.worker.eh = self.handler
self.logger = logger
assert not tqdm.write.called
Expand Down
13 changes: 7 additions & 6 deletions cumin/transports/clustershell.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,10 @@ def execute(self) -> int:
# Instantiate handler
# Schedule only the first command for the first batch, the following ones must be handled by the EventHandler
reporter = self._reporter() # Instantiate a new Reporter at each execution
progress_bars_instance = TqdmProgressBars() if self._progress_bars else NoProgress()
self._handler_instance = self.handler( # pylint: disable=not-callable
self.target, self.commands, reporter=reporter, success_threshold=self.success_threshold,
progress_bars=self._progress_bars)
progress_bars=progress_bars_instance)

self.logger.info(
"Executing commands %s on '%d' hosts: %s", self.commands, len(self.target.hosts), self.target.hosts)
Expand Down Expand Up @@ -504,16 +505,16 @@ class BaseEventHandler(Event.EventHandler):

# FIXME: not sure what the type of **kwargs should be
def __init__(self, target: Target, commands: List[Command], reporter: BaseReporter,
success_threshold: float = 1.0, progress_bars: bool = True, **kwargs: Any) -> None:
progress_bars: BaseExecutionProgress, success_threshold: float = 1.0, **kwargs: Any) -> None:
"""Event handler ClusterShell extension constructor.
Arguments:
target (cumin.transports.Target): a Target instance.
commands (list): the list of Command objects that has to be executed on the nodes.
reporter (cumin.transports.clustershell.BaseReporter): reporter used to output progress.
progress_bars (BaseExecutionProgress): the progress bars instance.
success_threshold (float, optional): the success threshold, a :py:class:`float` between ``0`` and ``1``,
to consider the execution successful.
progress_bars (bool): should progress bars be displayed
**kwargs (optional): additional keyword arguments that might be used by derived classes.
"""
Expand All @@ -538,7 +539,7 @@ def __init__(self, target: Target, commands: List[Command], reporter: BaseReport
for node_name in target.first_batch:
self.nodes[node_name].state.update(State.scheduled)

self.progress: BaseExecutionProgress = TqdmProgressBars() if progress_bars else NoProgress()
self.progress = progress_bars
self.reporter = reporter

def close(self, task):
Expand Down Expand Up @@ -677,7 +678,7 @@ class SyncEventHandler(BaseEventHandler):
"""

def __init__(self, target: Target, commands: List[Command], reporter: BaseReporter,
success_threshold: float = 1.0, progress_bars: bool = True, **kwargs: Any) -> None:
progress_bars: BaseExecutionProgress, success_threshold: float = 1.0, **kwargs: Any) -> None:
"""Define a custom ClusterShell event handler to execute commands synchronously.
:Parameters:
Expand Down Expand Up @@ -888,7 +889,7 @@ class AsyncEventHandler(BaseEventHandler):
"""

def __init__(self, target: Target, commands: List[Command], reporter: BaseReporter,
success_threshold: float = 1.0, progress_bars: bool = True, **kwargs: Any) -> None:
progress_bars: BaseExecutionProgress, success_threshold: float = 1.0, **kwargs: Any) -> None:
"""Define a custom ClusterShell event handler to execute commands asynchronously between nodes.
:Parameters:
Expand Down

0 comments on commit 5d56e11

Please sign in to comment.