Skip to content

Commit

Permalink
Adding some type annotations to clustershell.py
Browse files Browse the repository at this point in the history
Change-Id: I652b6b506fa3f500c4a2aa31ecb1f145592f51a1
  • Loading branch information
Guillaume Lederrey committed Sep 29, 2020
1 parent d46cf8b commit 17a19e7
Showing 1 changed file with 23 additions and 19 deletions.
42 changes: 23 additions & 19 deletions cumin/transports/clustershell.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import threading

from collections import Counter, defaultdict
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

from ClusterShell import Event, Task
from tqdm import tqdm
Expand All @@ -31,15 +31,15 @@ class defined in ``cumin/transports/clustershell.py``, implementing its abstract
object the handler to the transport.
"""

def __init__(self, config, target):
def __init__(self, config: dict, target: Target) -> None:
"""Worker ClusterShell constructor.
:Parameters:
according to parent :py:meth:`cumin.transports.BaseWorker.__init__`.
"""
super().__init__(config, target)
self.task = Task.task_self() # Initialize a ClusterShell task
self._handler_instance = None
self._handler_instance: Event.EventHandler = None

# Set any ClusterShell task options
for key, value in config.get('clustershell', {}).items():
Expand All @@ -48,7 +48,7 @@ def __init__(self, config, target):
else:
self.task.set_info(key, value)

def execute(self):
def execute(self) -> int:
"""Execute the commands on all the targets using the handler.
Concrete implementation of parent abstract method.
Expand All @@ -72,7 +72,8 @@ def execute(self):
self.task.shell(self.commands[0].command, nodes=self.target.first_batch, handler=self._handler_instance,
timeout=self.commands[0].timeout, stdin=False)

return_value = 0
# FIXME: return_value should not be optional
return_value: Optional[int] = 0
try:
self.task.run(timeout=self.timeout, stdin=False)
self.task.join()
Expand Down Expand Up @@ -100,7 +101,7 @@ def get_results(self):
yield nodeset_fromlist(nodelist), output

@property
def handler(self):
def handler(self) -> Optional[Type['BaseEventHandler']]:
"""Concrete implementation of parent abstract getter and setter.
Accepted values for the setter:
Expand All @@ -115,7 +116,7 @@ def handler(self):
return self._handler

@handler.setter
def handler(self, value):
def handler(self, value: Union[Type['BaseEventHandler'], str]) -> None:
"""Setter for the `handler` property. The relative documentation is in the getter."""
if isinstance(value, type) and issubclass(value, BaseEventHandler):
self._handler = value
Expand Down Expand Up @@ -144,7 +145,7 @@ class Node:
"""

def __init__(self, name, commands):
def __init__(self, name: str, commands: List[Command]):
"""Node class constructor with default values.
Arguments:
Expand Down Expand Up @@ -390,7 +391,7 @@ def __init__(self, target: Target, commands: List[Command], success_threshold: f
self.lock = threading.Lock() # Used to update instance variables coherently from within callbacks

# Execution management variables
self.return_value = None
self.return_value: Optional[int] = None
self.commands = commands
self.kwargs = kwargs # Allow to store custom parameters from subclasses without changing the signature
self.counters: Dict[str, int] = Counter()
Expand All @@ -404,7 +405,8 @@ def __init__(self, target: Target, commands: List[Command], success_threshold: f
for node_name in target.first_batch:
self.nodes[node_name].state.update(State.scheduled)

self.progress = ProgressBars() if progress_bars else NoProgress()
# TODO: introduce a super type to ProgressBars / NoProgress
self.progress: Union[ProgressBars, NoProgress] = ProgressBars() if progress_bars else NoProgress()
self.reporter = Reporter()

def close(self, task):
Expand Down Expand Up @@ -542,7 +544,8 @@ class SyncEventHandler(BaseEventHandler):
enough nodes before proceeding with the next one.
"""

def __init__(self, target, commands, success_threshold=1.0, progress_bars=True, **kwargs):
def __init__(self, target: Target, commands: List[Command], success_threshold: float = 1.0,
progress_bars: bool = True, **kwargs: Any) -> None:
"""Define a custom ClusterShell event handler to execute commands synchronously.
:Parameters:
Expand All @@ -555,7 +558,7 @@ def __init__(self, target, commands, success_threshold=1.0, progress_bars=True,
self.start_command()
self.aborted = False

def start_command(self, schedule=False):
def start_command(self, schedule: bool = False) -> None:
"""Initialize progress bars and variables for this command execution.
Executed at the start of each command.
Expand Down Expand Up @@ -586,7 +589,7 @@ def start_command(self, schedule=False):
Task.task_self().flush_buffers()
Task.task_self().shell(command.command, nodes=first_batch_set, handler=self, timeout=command.timeout)

def end_command(self):
def end_command(self) -> bool:
"""Command terminated, print the result and schedule the next command if criteria are met.
Executed at the end of each command inside a lock.
Expand Down Expand Up @@ -627,7 +630,7 @@ def end_command(self):

return True

def on_timeout(self, task):
def on_timeout(self, task: Task) -> None:
"""Override parent class `on_timeout` method to run `end_command`.
:Parameters:
Expand Down Expand Up @@ -729,7 +732,7 @@ def ev_timer(self, timer): # noqa, mccabe: MC0001 too complex (15) FIXME
if restart:
self.start_command(schedule=True)

def close(self, task):
def close(self, task: Task) -> None:
"""Concrete implementation of parent abstract method to print the success nodes report.
:Parameters:
Expand All @@ -754,7 +757,8 @@ class AsyncEventHandler(BaseEventHandler):
orchestration between the nodes.
"""

def __init__(self, target, commands, success_threshold=1.0, progress_bars=True, **kwargs):
def __init__(self, target: Target, commands: List[Command], success_threshold: float = 1.0,
progress_bars: bool = True, **kwargs: Any) -> None:
"""Define a custom ClusterShell event handler to execute commands asynchronously between nodes.
:Parameters:
Expand Down Expand Up @@ -836,7 +840,7 @@ def ev_timer(self, timer):
else:
self.logger.debug('No more nodes left')

def close(self, task):
def close(self, task: Task) -> None:
"""Concrete implementation of parent abstract method to print the nodes reports and close progress bars.
:Parameters:
Expand Down Expand Up @@ -864,8 +868,8 @@ def close(self, task):
self.return_value = 1


worker_class = ClusterShellWorker # pylint: disable=invalid-name
worker_class: Type[BaseWorker] = ClusterShellWorker # pylint: disable=invalid-name
"""Required by the transport auto-loader in :py:meth:`cumin.transport.Transport.new`."""

DEFAULT_HANDLERS = {'sync': SyncEventHandler, 'async': AsyncEventHandler}
DEFAULT_HANDLERS: Dict[str, Type[Event.EventHandler]] = {'sync': SyncEventHandler, 'async': AsyncEventHandler}
"""dict: mapping of available default event handlers for :py:class:`ClusterShellWorker`."""

0 comments on commit 17a19e7

Please sign in to comment.