diff --git a/ignite/contrib/handlers/base_logger.py b/ignite/contrib/handlers/base_logger.py index 1d983ee59415..9a602c4397a4 100644 --- a/ignite/contrib/handlers/base_logger.py +++ b/ignite/contrib/handlers/base_logger.py @@ -1,18 +1,19 @@ import numbers import warnings from abc import ABCMeta, abstractmethod -from typing import Any, Callable, List, Optional, Sequence, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, Union import torch import torch.nn as nn from torch.optim import Optimizer -from ignite.engine import Engine, State +from ignite.engine import Engine, Events, State +from ignite.engine.events import CallableEventWithFilter, RemovableEventHandle class BaseHandler(metaclass=ABCMeta): @abstractmethod - def __call__(self, engine, logger, event_name): + def __call__(self, engine: Engine, logger: Any, event_name: Union[str, Events]) -> None: pass @@ -68,7 +69,7 @@ def __init__( if global_step_transform is None: - def global_step_transform(engine, event_name): + def global_step_transform(engine: Engine, event_name: Union[str, Events]) -> int: return engine.state.get_event_attrib_value(event_name) self.tag = tag @@ -76,7 +77,7 @@ def global_step_transform(engine, event_name): self.output_transform = output_transform self.global_step_transform = global_step_transform - def _setup_output_metrics(self, engine: Engine): + def _setup_output_metrics(self, engine: Engine) -> Dict[str, Any]: """Helper method to setup metrics to log """ metrics = {} @@ -108,14 +109,14 @@ class BaseWeightsScalarHandler(BaseHandler): Helper handler to log model's weights as scalars. """ - def __init__(self, model: nn.Module, reduction: Callable = torch.norm, tag: Optional[str] = None): + def __init__(self, model: nn.Module, reduction: Callable = torch.norm, tag: Optional[str] = None) -> None: if not isinstance(model, torch.nn.Module): raise TypeError("Argument model should be of type torch.nn.Module, " "but given {}".format(type(model))) if not callable(reduction): raise TypeError("Argument reduction should be callable, " "but given {}".format(type(reduction))) - def _is_0D_tensor(t: torch.Tensor): + def _is_0D_tensor(t: torch.Tensor) -> bool: return isinstance(t, torch.Tensor) and t.ndimension() == 0 # Test reduction function on a tensor @@ -147,7 +148,9 @@ class BaseLogger(metaclass=ABCMeta): """ - def attach(self, engine: Engine, log_handler: Callable, event_name: Any): + def attach( + self, engine: Engine, log_handler: Callable, event_name: Union[str, Events, CallableEventWithFilter] + ) -> RemovableEventHandle: """Attach the logger to the engine and execute `log_handler` function at `event_name` events. Args: @@ -167,7 +170,7 @@ def attach(self, engine: Engine, log_handler: Callable, event_name: Any): return engine.add_event_handler(event_name, log_handler, self, name) - def attach_output_handler(self, engine: Engine, event_name: Any, *args: Any, **kwargs: Any): + def attach_output_handler(self, engine: Engine, event_name: Any, *args: Any, **kwargs: Any) -> RemovableEventHandle: """Shortcut method to attach `OutputHandler` to the logger. 
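# --- Editor's note: illustrative usage sketch, not part of the patch ---
# The `attach` / `attach_output_handler` signatures annotated above are the public
# entry points of every concrete logger. A minimal sketch, assuming a `trainer`
# Engine built elsewhere; TensorboardLogger is used only as one example subclass.
from ignite.engine import Events
from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger

tb_logger = TensorboardLogger(log_dir="/tmp/tb_logs")
handle = tb_logger.attach_output_handler(
    trainer,                                     # assumed: ignite.engine.Engine defined elsewhere
    event_name=Events.ITERATION_COMPLETED,
    tag="training",
    output_transform=lambda loss: {"loss": loss},
)
# Per the return annotation added above, the RemovableEventHandle can detach it again:
# handle.remove()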
Args: @@ -183,7 +186,7 @@ def attach_output_handler(self, engine: Engine, event_name: Any, *args: Any, **k """ return self.attach(engine, self._create_output_handler(*args, **kwargs), event_name=event_name) - def attach_opt_params_handler(self, engine: Engine, event_name: Any, *args: Any, **kwargs: Any): + def attach_opt_params_handler(self, engine: Engine, event_name: Any, *args: Any, **kwargs: Any) -> None: """Shortcut method to attach `OptimizerParamsHandler` to the logger. Args: @@ -200,18 +203,18 @@ def attach_opt_params_handler(self, engine: Engine, event_name: Any, *args: Any, self.attach(engine, self._create_opt_params_handler(*args, **kwargs), event_name=event_name) @abstractmethod - def _create_output_handler(self, engine: Engine, *args: Any, **kwargs: Any): + def _create_output_handler(self, engine: Engine, *args: Any, **kwargs: Any) -> Callable: pass @abstractmethod - def _create_opt_params_handler(self, *args: Any, **kwargs: Any): + def _create_opt_params_handler(self, *args: Any, **kwargs: Any) -> Callable: pass - def __enter__(self): + def __enter__(self) -> "BaseLogger": return self - def __exit__(self, type, value, traceback): + def __exit__(self, type: Any, value: Any, traceback: Any) -> None: self.close() - def close(self): + def close(self) -> None: pass diff --git a/ignite/contrib/handlers/lr_finder.py b/ignite/contrib/handlers/lr_finder.py index 149a32e60cdf..2ad702171b01 100644 --- a/ignite/contrib/handlers/lr_finder.py +++ b/ignite/contrib/handlers/lr_finder.py @@ -4,7 +4,7 @@ import tempfile import warnings from pathlib import Path -from typing import Callable, Mapping, Optional +from typing import Any, Callable, Dict, List, Mapping, Optional, Union import torch from torch.optim import Optimizer @@ -71,11 +71,11 @@ class FastaiLRFinder: fastai/lr_find: https://github.com/fastai/fastai """ - def __init__(self): + def __init__(self) -> None: self._diverge_flag = False - self._history = None + self._history = {} # type: Dict[str, List[Any]] self._best_loss = None - self._lr_schedule = None + self._lr_schedule = None # type: Optional[Union[LRScheduler, PiecewiseLinear]] self.logger = logging.getLogger(__name__) def _run( @@ -88,7 +88,7 @@ def _run( step_mode: str, smooth_f: float, diverge_th: float, - ): + ) -> None: self._history = {"lr": [], "loss": []} self._best_loss = None @@ -98,7 +98,7 @@ def _run( if num_iter is None: num_iter = trainer.state.epoch_length * trainer.state.max_epochs else: - max_iter = trainer.state.epoch_length * trainer.state.max_epochs + max_iter = trainer.state.epoch_length * trainer.state.max_epochs # type: ignore[operator] if num_iter > max_iter: warnings.warn( "Desired num_iter {} is unreachable with the current run setup of {} iteration " @@ -127,16 +127,16 @@ def _run( if not trainer.has_event_handler(self._lr_schedule): trainer.add_event_handler(Events.ITERATION_COMPLETED, self._lr_schedule, num_iter) - def _reset(self, trainer: Engine): + def _reset(self, trainer: Engine) -> None: self.logger.debug("Completed LR finder run") - trainer.remove_event_handler(self._lr_schedule, Events.ITERATION_COMPLETED) + trainer.remove_event_handler(self._lr_schedule, Events.ITERATION_COMPLETED) # type: ignore[arg-type] trainer.remove_event_handler(self._log_lr_and_loss, Events.ITERATION_COMPLETED) trainer.remove_event_handler(self._reached_num_iterations, Events.ITERATION_COMPLETED) - def _log_lr_and_loss(self, trainer: Engine, output_transform: Callable, smooth_f: float, diverge_th: float): + def _log_lr_and_loss(self, trainer: Engine, 
output_transform: Callable, smooth_f: float, diverge_th: float) -> None: output = trainer.state.output loss = output_transform(output) - lr = self._lr_schedule.get_param() + lr = self._lr_schedule.get_param() # type: ignore[union-attr] self._history["lr"].append(lr) if trainer.state.iteration == 1: self._best_loss = loss @@ -148,16 +148,16 @@ def _log_lr_and_loss(self, trainer: Engine, output_transform: Callable, smooth_f self._history["loss"].append(loss) # Check if the loss has diverged; if it has, stop the trainer - if self._history["loss"][-1] > diverge_th * self._best_loss: + if self._history["loss"][-1] > diverge_th * self._best_loss: # type: ignore[operator] self._diverge_flag = True self.logger.info("Stopping early, the loss has diverged") trainer.terminate() - def _reached_num_iterations(self, trainer: Engine, num_iter: int): + def _reached_num_iterations(self, trainer: Engine, num_iter: int) -> None: if trainer.state.iteration > num_iter: trainer.terminate() - def _warning(self, _): + def _warning(self, _: Any) -> None: if not self._diverge_flag: warnings.warn( "Run completed without loss diverging, increase end_lr, decrease diverge_th or look" @@ -165,7 +165,7 @@ def _warning(self, _): UserWarning, ) - def _detach(self, trainer: Engine): + def _detach(self, trainer: Engine) -> None: """ Detaches lr_finder from trainer. @@ -180,13 +180,13 @@ def _detach(self, trainer: Engine): if trainer.has_event_handler(self._reset, Events.COMPLETED): trainer.remove_event_handler(self._reset, Events.COMPLETED) - def get_results(self): + def get_results(self) -> Dict[str, List[Any]]: """ Returns: dictionary with loss and lr logs fromm the previous run """ return self._history - def plot(self, skip_start: int = 10, skip_end: int = 5, log_lr: bool = True): + def plot(self, skip_start: int = 10, skip_end: int = 5, log_lr: bool = True) -> None: """Plots the learning rate range test. This method requires `matplotlib` package to be installed: @@ -211,7 +211,7 @@ def plot(self, skip_start: int = 10, skip_end: int = 5, log_lr: bool = True): "Please install it with command: \n pip install matplotlib" ) - if self._history is None: + if not self._history: raise RuntimeError("learning rate finder didn't run yet so results can't be plotted") if skip_start < 0: @@ -239,11 +239,11 @@ def plot(self, skip_start: int = 10, skip_end: int = 5, log_lr: bool = True): plt.ylabel("Loss") plt.show() - def lr_suggestion(self): + def lr_suggestion(self) -> Any: """ Returns: learning rate at the minimum numerical gradient """ - if self._history is None: + if not self._history: raise RuntimeError("learning rate finder didn't run yet so lr_suggestion can't be returned") loss = self._history["loss"] grads = torch.tensor([loss[i] - loss[i - 1] for i in range(1, len(loss))]) @@ -261,7 +261,7 @@ def attach( step_mode: str = "exp", smooth_f: float = 0.05, diverge_th: float = 5.0, - ): + ) -> Any: """Attaches lr_finder to a given trainer. It also resets model and optimizer at the end of the run. 
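# --- Editor's note: illustrative usage sketch, not part of the patch ---
# How the FastaiLRFinder methods retyped above fit together; `model`, `optimizer`,
# `trainer` and `dataloader` are assumed to exist elsewhere.
from ignite.contrib.handlers import FastaiLRFinder

lr_finder = FastaiLRFinder()
to_save = {"model": model, "optimizer": optimizer}

# attach() is a context manager that yields a trainer wired with the LR schedule
with lr_finder.attach(trainer, to_save=to_save, end_lr=10.0) as trainer_with_lr_finder:
    trainer_with_lr_finder.run(dataloader)

results = lr_finder.get_results()      # now typed as Dict[str, List[Any]] with "lr" and "loss"
suggested_lr = lr_finder.lr_suggestion()
lr_finder.plot()                        # requires matplotlib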
Usage: @@ -372,12 +372,12 @@ class _ExponentialLR(_LRScheduler): """ - def __init__(self, optimizer: Optimizer, end_lr: float, num_iter: int, last_epoch: int = -1): + def __init__(self, optimizer: Optimizer, end_lr: float, num_iter: int, last_epoch: int = -1) -> None: self.end_lr = end_lr self.num_iter = num_iter super(_ExponentialLR, self).__init__(optimizer, last_epoch) - def get_lr(self): - curr_iter = self.last_epoch + 1 + def get_lr(self) -> List[float]: # type: ignore + curr_iter = self.last_epoch + 1 # type: ignore[attr-defined] r = curr_iter / self.num_iter - return [base_lr * (self.end_lr / base_lr) ** r for base_lr in self.base_lrs] + return [base_lr * (self.end_lr / base_lr) ** r for base_lr in self.base_lrs] # type: ignore[attr-defined] diff --git a/ignite/contrib/handlers/mlflow_logger.py b/ignite/contrib/handlers/mlflow_logger.py index bc8478f0d17a..d55b9698363f 100644 --- a/ignite/contrib/handlers/mlflow_logger.py +++ b/ignite/contrib/handlers/mlflow_logger.py @@ -1,12 +1,12 @@ import numbers import warnings -from typing import Any, Callable, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union import torch from torch.optim import Optimizer from ignite.contrib.handlers.base_logger import BaseLogger, BaseOptimizerParamsHandler, BaseOutputHandler -from ignite.engine import Engine, EventEnum +from ignite.engine import Engine, Events from ignite.handlers import global_step_from_engine __all__ = ["MLflowLogger", "OutputHandler", "OptimizerParamsHandler", "global_step_from_engine"] @@ -86,7 +86,7 @@ class MLflowLogger(BaseLogger): ) """ - def __init__(self, tracking_uri: Optional[str] = None): + def __init__(self, tracking_uri: Optional[str] = None) -> None: try: import mlflow except ImportError: @@ -102,21 +102,21 @@ def __init__(self, tracking_uri: Optional[str] = None): if self.active_run is None: self.active_run = mlflow.start_run() - def __getattr__(self, attr: Any): + def __getattr__(self, attr: Any) -> Any: import mlflow return getattr(mlflow, attr) - def close(self): + def close(self) -> None: import mlflow mlflow.end_run() - def _create_output_handler(self, *args: Any, **kwargs: Any): + def _create_output_handler(self, *args: Any, **kwargs: Any) -> "OutputHandler": return OutputHandler(*args, **kwargs) - def _create_opt_params_handler(self, *args: Any, **kwargs: Any): + def _create_opt_params_handler(self, *args: Any, **kwargs: Any) -> "OptimizerParamsHandler": return OptimizerParamsHandler(*args, **kwargs) @@ -212,17 +212,17 @@ def __init__( metric_names: Optional[Union[str, List[str]]] = None, output_transform: Optional[Callable] = None, global_step_transform: Optional[Callable] = None, - ): + ) -> None: super(OutputHandler, self).__init__(tag, metric_names, output_transform, global_step_transform) - def __call__(self, engine: Engine, logger: MLflowLogger, event_name: Union[str, EventEnum]): + def __call__(self, engine: Engine, logger: MLflowLogger, event_name: Union[str, Events]) -> None: if not isinstance(logger, MLflowLogger): raise TypeError("Handler 'OutputHandler' works only with MLflowLogger") metrics = self._setup_output_metrics(engine) - global_step = self.global_step_transform(engine, event_name) + global_step = self.global_step_transform(engine, event_name) # type: ignore[misc] if not isinstance(global_step, int): raise TypeError( @@ -230,10 +230,10 @@ def __call__(self, engine: Engine, logger: MLflowLogger, event_name: Union[str, " Please check the output of global_step_transform.".format(type(global_step)) ) - 
rendered_metrics = {} + rendered_metrics = {} # type: Dict[str, float] for key, value in metrics.items(): if isinstance(value, numbers.Number): - rendered_metrics["{} {}".format(self.tag, key)] = value + rendered_metrics["{} {}".format(self.tag, key)] = value # type: ignore[assignment] elif isinstance(value, torch.Tensor) and value.ndimension() == 0: rendered_metrics["{} {}".format(self.tag, key)] = value.item() elif isinstance(value, torch.Tensor) and value.ndimension() == 1: @@ -290,10 +290,10 @@ class OptimizerParamsHandler(BaseOptimizerParamsHandler): tag (str, optional): common title for all produced plots. For example, 'generator' """ - def __init__(self, optimizer: Optimizer, param_name: str = "lr", tag: Optional[str] = None): + def __init__(self, optimizer: Optimizer, param_name: str = "lr", tag: Optional[str] = None) -> None: super(OptimizerParamsHandler, self).__init__(optimizer, param_name, tag) - def __call__(self, engine: Engine, logger: MLflowLogger, event_name: Union[str, EventEnum]): + def __call__(self, engine: Engine, logger: MLflowLogger, event_name: Union[str, Events]) -> None: if not isinstance(logger, MLflowLogger): raise TypeError("Handler OptimizerParamsHandler works only with MLflowLogger") diff --git a/ignite/contrib/handlers/neptune_logger.py b/ignite/contrib/handlers/neptune_logger.py index 18cf523ee62d..98216a26206c 100644 --- a/ignite/contrib/handlers/neptune_logger.py +++ b/ignite/contrib/handlers/neptune_logger.py @@ -15,7 +15,7 @@ BaseOutputHandler, BaseWeightsScalarHandler, ) -from ignite.engine import Engine, EventEnum +from ignite.engine import Engine, Events from ignite.handlers import global_step_from_engine from ignite.handlers.checkpoint import BaseSaveHandler @@ -176,13 +176,13 @@ def score_function(engine): """ - def __getattr__(self, attr: Any): + def __getattr__(self, attr: Any) -> Any: import neptune return getattr(neptune, attr) - def __init__(self, *args: Any, **kwargs: Any): + def __init__(self, *args: Any, **kwargs: Any) -> None: try: import neptune except ImportError: @@ -205,13 +205,13 @@ def __init__(self, *args: Any, **kwargs: Any): self.experiment = neptune.create_experiment(**self._experiment_kwargs) - def close(self): + def close(self) -> None: self.stop() - def _create_output_handler(self, *args: Any, **kwargs: Any): + def _create_output_handler(self, *args: Any, **kwargs: Any) -> "OutputHandler": return OutputHandler(*args, **kwargs) - def _create_opt_params_handler(self, *args: Any, **kwargs: Any): + def _create_opt_params_handler(self, *args: Any, **kwargs: Any) -> "OptimizerParamsHandler": return OptimizerParamsHandler(*args, **kwargs) @@ -323,17 +323,17 @@ def __init__( metric_names: Optional[Union[str, List[str]]] = None, output_transform: Optional[Callable] = None, global_step_transform: Optional[Callable] = None, - ): + ) -> None: super(OutputHandler, self).__init__(tag, metric_names, output_transform, global_step_transform) - def __call__(self, engine: Engine, logger: NeptuneLogger, event_name: Union[str, EventEnum]): + def __call__(self, engine: Engine, logger: NeptuneLogger, event_name: Union[str, Events]) -> None: if not isinstance(logger, NeptuneLogger): raise TypeError("Handler OutputHandler works only with NeptuneLogger") metrics = self._setup_output_metrics(engine) - global_step = self.global_step_transform(engine, event_name) + global_step = self.global_step_transform(engine, event_name) # type: ignore[misc] if not isinstance(global_step, int): raise TypeError( @@ -391,10 +391,10 @@ class 
OptimizerParamsHandler(BaseOptimizerParamsHandler): tag (str, optional): common title for all produced plots. For example, "generator" """ - def __init__(self, optimizer: Optimizer, param_name: str = "lr", tag: Optional[str] = None): + def __init__(self, optimizer: Optimizer, param_name: str = "lr", tag: Optional[str] = None) -> None: super(OptimizerParamsHandler, self).__init__(optimizer, param_name, tag) - def __call__(self, engine: Engine, logger: NeptuneLogger, event_name: Union[str, EventEnum]): + def __call__(self, engine: Engine, logger: NeptuneLogger, event_name: Union[str, Events]) -> None: if not isinstance(logger, NeptuneLogger): raise TypeError("Handler OptimizerParamsHandler works only with NeptuneLogger") @@ -445,10 +445,10 @@ class WeightsScalarHandler(BaseWeightsScalarHandler): """ - def __init__(self, model: nn.Module, reduction: Callable = torch.norm, tag: Optional[str] = None): + def __init__(self, model: nn.Module, reduction: Callable = torch.norm, tag: Optional[str] = None) -> None: super(WeightsScalarHandler, self).__init__(model, reduction, tag=tag) - def __call__(self, engine: Engine, logger: NeptuneLogger, event_name: Union[str, EventEnum]): + def __call__(self, engine: Engine, logger: NeptuneLogger, event_name: Union[str, Events]) -> None: if not isinstance(logger, NeptuneLogger): raise TypeError("Handler WeightsScalarHandler works only with NeptuneLogger") @@ -503,10 +503,10 @@ class GradsScalarHandler(BaseWeightsScalarHandler): """ - def __init__(self, model: nn.Module, reduction: Callable = torch.norm, tag: Optional[str] = None): + def __init__(self, model: nn.Module, reduction: Callable = torch.norm, tag: Optional[str] = None) -> None: super(GradsScalarHandler, self).__init__(model, reduction, tag=tag) - def __call__(self, engine: Engine, logger: NeptuneLogger, event_name: Any): + def __call__(self, engine: Engine, logger: NeptuneLogger, event_name: Union[str, Events]) -> None: if not isinstance(logger, NeptuneLogger): raise TypeError("Handler GradsScalarHandler works only with NeptuneLogger") @@ -590,7 +590,7 @@ def __call__(self, checkpoint: Mapping, filename: str, metadata: Optional[Mappin with tempfile.NamedTemporaryFile() as tmp: # we can not use tmp.name to open tmp.file twice on Win32 # https://docs.python.org/3/library/tempfile.html#tempfile.NamedTemporaryFile - torch.save(checkpoint, tmp.file) + torch.save(checkpoint, tmp.file) # type: ignore[attr-defined] self._logger.log_artifact(tmp.name, filename) @idist.one_rank_only(with_barrier=True) diff --git a/ignite/contrib/handlers/polyaxon_logger.py b/ignite/contrib/handlers/polyaxon_logger.py index 9d75dff12481..70a2416aa569 100644 --- a/ignite/contrib/handlers/polyaxon_logger.py +++ b/ignite/contrib/handlers/polyaxon_logger.py @@ -1,12 +1,12 @@ import numbers import warnings -from typing import Any, Callable, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union import torch from torch.optim import Optimizer from ignite.contrib.handlers.base_logger import BaseLogger, BaseOptimizerParamsHandler, BaseOutputHandler -from ignite.engine import Engine, EventEnum +from ignite.engine import Engine, Events from ignite.handlers import global_step_from_engine __all__ = ["PolyaxonLogger", "OutputHandler", "OptimizerParamsHandler", "global_step_from_engine"] @@ -91,7 +91,7 @@ class PolyaxonLogger(BaseLogger): """ - def __init__(self, *args: Any, **kwargs: Any): + def __init__(self, *args: Any, **kwargs: Any) -> None: try: from polyaxon_client.tracking import Experiment except 
ImportError: @@ -102,13 +102,13 @@ def __init__(self, *args: Any, **kwargs: Any): self.experiment = Experiment(*args, **kwargs) - def __getattr__(self, attr: Any): + def __getattr__(self, attr: Any) -> Any: return getattr(self.experiment, attr) - def _create_output_handler(self, *args: Any, **kwargs: Any): + def _create_output_handler(self, *args: Any, **kwargs: Any) -> "OutputHandler": return OutputHandler(*args, **kwargs) - def _create_opt_params_handler(self, *args: Any, **kwargs: Any): + def _create_opt_params_handler(self, *args: Any, **kwargs: Any) -> "OptimizerParamsHandler": return OptimizerParamsHandler(*args, **kwargs) @@ -204,17 +204,17 @@ def __init__( metric_names: Optional[List[str]] = None, output_transform: Optional[Callable] = None, global_step_transform: Optional[Callable] = None, - ): + ) -> None: super(OutputHandler, self).__init__(tag, metric_names, output_transform, global_step_transform) - def __call__(self, engine: Engine, logger: PolyaxonLogger, event_name: Union[str, EventEnum]): + def __call__(self, engine: Engine, logger: PolyaxonLogger, event_name: Union[str, Events]) -> None: if not isinstance(logger, PolyaxonLogger): raise RuntimeError("Handler 'OutputHandler' works only with PolyaxonLogger") metrics = self._setup_output_metrics(engine) - global_step = self.global_step_transform(engine, event_name) + global_step = self.global_step_transform(engine, event_name) # type: ignore[misc] if not isinstance(global_step, int): raise TypeError( @@ -222,7 +222,7 @@ def __call__(self, engine: Engine, logger: PolyaxonLogger, event_name: Union[str " Please check the output of global_step_transform.".format(type(global_step)) ) - rendered_metrics = {"step": global_step} + rendered_metrics = {"step": global_step} # type: Dict[str, Union[float, numbers.Number]] for key, value in metrics.items(): if isinstance(value, numbers.Number): rendered_metrics["{}/{}".format(self.tag, key)] = value @@ -269,10 +269,10 @@ class OptimizerParamsHandler(BaseOptimizerParamsHandler): tag (str, optional): common title for all produced plots. 
For example, "generator" """ - def __init__(self, optimizer: Optimizer, param_name: str = "lr", tag: Optional[str] = None): + def __init__(self, optimizer: Optimizer, param_name: str = "lr", tag: Optional[str] = None) -> None: super(OptimizerParamsHandler, self).__init__(optimizer, param_name, tag) - def __call__(self, engine: Engine, logger: PolyaxonLogger, event_name: Union[str, EventEnum]): + def __call__(self, engine: Engine, logger: PolyaxonLogger, event_name: Union[str, Events]) -> None: if not isinstance(logger, PolyaxonLogger): raise RuntimeError("Handler OptimizerParamsHandler works only with PolyaxonLogger") diff --git a/ignite/contrib/handlers/stores.py b/ignite/contrib/handlers/stores.py index 6bd2b5167bdb..51212bb445e1 100644 --- a/ignite/contrib/handlers/stores.py +++ b/ignite/contrib/handlers/stores.py @@ -1,4 +1,4 @@ -from typing import Callable, Optional +from typing import Callable, List, Tuple, Union from ignite.engine import Engine, Events @@ -33,17 +33,17 @@ def log_training_results(engine): """ - def __init__(self, output_transform: Callable = lambda x: x): - self.data = None + def __init__(self, output_transform: Callable = lambda x: x) -> None: + self.data = [] # type: List[Union[int, Tuple[int, int]]] self.output_transform = output_transform - def reset(self): + def reset(self) -> None: self.data = [] - def update(self, engine: Engine): + def update(self, engine: Engine) -> None: output = self.output_transform(engine.state.output) self.data.append(output) - def attach(self, engine: Engine): + def attach(self, engine: Engine) -> None: engine.add_event_handler(Events.EPOCH_STARTED, self.reset) engine.add_event_handler(Events.ITERATION_COMPLETED, self.update) diff --git a/ignite/contrib/handlers/tensorboard_logger.py b/ignite/contrib/handlers/tensorboard_logger.py index 798e11769f56..9acc2f1bda31 100644 --- a/ignite/contrib/handlers/tensorboard_logger.py +++ b/ignite/contrib/handlers/tensorboard_logger.py @@ -13,7 +13,7 @@ BaseWeightsHistHandler, BaseWeightsScalarHandler, ) -from ignite.engine import Engine, EventEnum +from ignite.engine import Engine, EventEnum, Events from ignite.handlers import global_step_from_engine __all__ = [ @@ -149,12 +149,12 @@ class TensorboardLogger(BaseLogger): """ - def __init__(self, *args: Any, **kwargs: Any): + def __init__(self, *args: Any, **kwargs: Any) -> None: try: from tensorboardX import SummaryWriter except ImportError: try: - from torch.utils.tensorboard import SummaryWriter + from torch.utils.tensorboard import SummaryWriter # type: ignore[no-redef] except ImportError: raise RuntimeError( "This contrib module requires either tensorboardX or torch >= 1.2.0. 
" @@ -164,13 +164,13 @@ def __init__(self, *args: Any, **kwargs: Any): self.writer = SummaryWriter(*args, **kwargs) - def close(self): + def close(self) -> None: self.writer.close() - def _create_output_handler(self, *args: Any, **kwargs: Any): + def _create_output_handler(self, *args: Any, **kwargs: Any) -> "OutputHandler": return OutputHandler(*args, **kwargs) - def _create_opt_params_handler(self, *args: Any, **kwargs: Any): + def _create_opt_params_handler(self, *args: Any, **kwargs: Any) -> "OptimizerParamsHandler": return OptimizerParamsHandler(*args, **kwargs) @@ -266,17 +266,17 @@ def __init__( metric_names: Optional[List[str]] = None, output_transform: Optional[Callable] = None, global_step_transform: Optional[Callable] = None, - ): + ) -> None: super(OutputHandler, self).__init__(tag, metric_names, output_transform, global_step_transform) - def __call__(self, engine: Engine, logger: TensorboardLogger, event_name: Union[str, EventEnum]): + def __call__(self, engine: Engine, logger: TensorboardLogger, event_name: Union[str, EventEnum]) -> None: if not isinstance(logger, TensorboardLogger): raise RuntimeError("Handler 'OutputHandler' works only with TensorboardLogger") metrics = self._setup_output_metrics(engine) - global_step = self.global_step_transform(engine, event_name) + global_step = self.global_step_transform(engine, event_name) # type: ignore[misc] if not isinstance(global_step, int): raise TypeError( "global_step must be int, got {}." @@ -325,10 +325,10 @@ class OptimizerParamsHandler(BaseOptimizerParamsHandler): tag (str, optional): common title for all produced plots. For example, "generator" """ - def __init__(self, optimizer: Optimizer, param_name: str = "lr", tag: Optional[str] = None): + def __init__(self, optimizer: Optimizer, param_name: str = "lr", tag: Optional[str] = None) -> None: super(OptimizerParamsHandler, self).__init__(optimizer, param_name, tag) - def __call__(self, engine: Engine, logger: TensorboardLogger, event_name: Union[str, EventEnum]): + def __call__(self, engine: Engine, logger: TensorboardLogger, event_name: Union[str, Events]) -> None: if not isinstance(logger, TensorboardLogger): raise RuntimeError("Handler OptimizerParamsHandler works only with TensorboardLogger") @@ -371,10 +371,10 @@ class WeightsScalarHandler(BaseWeightsScalarHandler): """ - def __init__(self, model: nn.Module, reduction: Callable = torch.norm, tag: Optional[str] = None): + def __init__(self, model: nn.Module, reduction: Callable = torch.norm, tag: Optional[str] = None) -> None: super(WeightsScalarHandler, self).__init__(model, reduction, tag=tag) - def __call__(self, engine: Engine, logger: TensorboardLogger, event_name: Union[str, EventEnum]): + def __call__(self, engine: Engine, logger: TensorboardLogger, event_name: Union[str, Events]) -> None: if not isinstance(logger, TensorboardLogger): raise RuntimeError("Handler 'WeightsScalarHandler' works only with TensorboardLogger") @@ -419,7 +419,7 @@ class WeightsHistHandler(BaseWeightsHistHandler): def __init__(self, model: nn.Module, tag: Optional[str] = None): super(WeightsHistHandler, self).__init__(model, tag=tag) - def __call__(self, engine: Engine, logger: TensorboardLogger, event_name: Union[str, EventEnum]): + def __call__(self, engine: Engine, logger: TensorboardLogger, event_name: Union[str, Events]) -> None: if not isinstance(logger, TensorboardLogger): raise RuntimeError("Handler 'WeightsHistHandler' works only with TensorboardLogger") @@ -465,10 +465,10 @@ class GradsScalarHandler(BaseWeightsScalarHandler): 
""" - def __init__(self, model: nn.Module, reduction: Callable = torch.norm, tag: Optional[str] = None): + def __init__(self, model: nn.Module, reduction: Callable = torch.norm, tag: Optional[str] = None) -> None: super(GradsScalarHandler, self).__init__(model, reduction, tag=tag) - def __call__(self, engine: Engine, logger: TensorboardLogger, event_name: Union[str, EventEnum]): + def __call__(self, engine: Engine, logger: TensorboardLogger, event_name: Union[str, Events]) -> None: if not isinstance(logger, TensorboardLogger): raise RuntimeError("Handler 'GradsScalarHandler' works only with TensorboardLogger") @@ -509,10 +509,10 @@ class GradsHistHandler(BaseWeightsHistHandler): """ - def __init__(self, model: nn.Module, tag: Optional[str] = None): + def __init__(self, model: nn.Module, tag: Optional[str] = None) -> None: super(GradsHistHandler, self).__init__(model, tag=tag) - def __call__(self, engine: Engine, logger: TensorboardLogger, event_name: Union[str, EventEnum]): + def __call__(self, engine: Engine, logger: TensorboardLogger, event_name: Union[str, Events]) -> None: if not isinstance(logger, TensorboardLogger): raise RuntimeError("Handler 'GradsHistHandler' works only with TensorboardLogger") diff --git a/ignite/contrib/handlers/time_profilers.py b/ignite/contrib/handlers/time_profilers.py index 85046273a579..95e03f55a967 100644 --- a/ignite/contrib/handlers/time_profilers.py +++ b/ignite/contrib/handlers/time_profilers.py @@ -1,6 +1,6 @@ import functools from collections import OrderedDict -from typing import Any, Callable, Dict, Iterable, List, Mapping, Sequence, Union +from typing import Any, Callable, Dict, List, Mapping, Sequence, Tuple, Union, cast import torch @@ -42,14 +42,14 @@ def log_intermediate_results(): Events.DATALOADER_STOP_ITERATION, ] - def __init__(self): + def __init__(self) -> None: self._dataflow_timer = Timer() self._processing_timer = Timer() self._event_handlers_timer = Timer() - self.dataflow_times = None - self.processing_times = None - self.event_handlers_times = None + self.dataflow_times = torch.zeros(1) + self.processing_times = torch.zeros(1) + self.event_handlers_times = {} # type: Dict[EventEnum, torch.Tensor] self._events = [ Events.EPOCH_STARTED, @@ -79,7 +79,7 @@ def __init__(self): self._as_last_completed, ] - def _reset(self, num_epochs: int, total_num_iters: int): + def _reset(self, num_epochs: int, total_num_iters: int) -> None: self.dataflow_times = torch.zeros(total_num_iters) self.processing_times = torch.zeros(total_num_iters) self.event_handlers_times = { @@ -93,13 +93,18 @@ def _reset(self, num_epochs: int, total_num_iters: int): Events.GET_BATCH_STARTED: torch.zeros(total_num_iters), } - def _as_first_started(self, engine: Engine): + def _as_first_started(self, engine: Engine) -> None: if hasattr(engine.state.dataloader, "__len__"): - num_iters_per_epoch = len(engine.state.dataloader) + num_iters_per_epoch = len(engine.state.dataloader) # type: ignore[arg-type] else: + if engine.state.epoch_length is None: + raise ValueError( + "As epoch_length is not set, we can not use BasicTimeProfiler in this case." + "Please, set trainer.run(..., epoch_length=epoch_length) in order to fix this." 
+ ) num_iters_per_epoch = engine.state.epoch_length - self.max_epochs = engine.state.max_epochs + self.max_epochs = cast(int, engine.state.max_epochs) self.total_num_iters = self.max_epochs * num_iters_per_epoch self._reset(self.max_epochs, self.total_num_iters) @@ -125,30 +130,30 @@ def _as_first_started(self, engine: Engine): # Let's go self._event_handlers_timer.reset() - def _as_last_started(self, engine: Engine): + def _as_last_started(self, engine: Engine) -> None: self.event_handlers_times[Events.STARTED][0] = self._event_handlers_timer.value() - def _as_first_epoch_started(self, engine: Engine): + def _as_first_epoch_started(self, engine: Engine) -> None: self._event_handlers_timer.reset() - def _as_last_epoch_started(self, engine: Engine): + def _as_last_epoch_started(self, engine: Engine) -> None: t = self._event_handlers_timer.value() e = engine.state.epoch - 1 self.event_handlers_times[Events.EPOCH_STARTED][e] = t - def _as_first_get_batch_started(self, engine: Engine): + def _as_first_get_batch_started(self, engine: Engine) -> None: self._event_handlers_timer.reset() self._dataflow_timer.reset() - def _as_last_get_batch_started(self, engine: Engine): + def _as_last_get_batch_started(self, engine: Engine) -> None: t = self._event_handlers_timer.value() i = engine.state.iteration - 1 self.event_handlers_times[Events.GET_BATCH_STARTED][i] = t - def _as_first_get_batch_completed(self, engine: Engine): + def _as_first_get_batch_completed(self, engine: Engine) -> None: self._event_handlers_timer.reset() - def _as_last_get_batch_completed(self, engine: Engine): + def _as_last_get_batch_completed(self, engine: Engine) -> None: t = self._event_handlers_timer.value() i = engine.state.iteration - 1 self.event_handlers_times[Events.GET_BATCH_COMPLETED][i] = t @@ -158,40 +163,40 @@ def _as_last_get_batch_completed(self, engine: Engine): self._dataflow_timer.reset() - def _as_first_iter_started(self, engine: Engine): + def _as_first_iter_started(self, engine: Engine) -> None: self._event_handlers_timer.reset() - def _as_last_iter_started(self, engine: Engine): + def _as_last_iter_started(self, engine: Engine) -> None: t = self._event_handlers_timer.value() i = engine.state.iteration - 1 self.event_handlers_times[Events.ITERATION_STARTED][i] = t self._processing_timer.reset() - def _as_first_iter_completed(self, engine: Engine): + def _as_first_iter_completed(self, engine: Engine) -> None: t = self._processing_timer.value() i = engine.state.iteration - 1 self.processing_times[i] = t self._event_handlers_timer.reset() - def _as_last_iter_completed(self, engine: Engine): + def _as_last_iter_completed(self, engine: Engine) -> None: t = self._event_handlers_timer.value() i = engine.state.iteration - 1 self.event_handlers_times[Events.ITERATION_COMPLETED][i] = t - def _as_first_epoch_completed(self, engine: Engine): + def _as_first_epoch_completed(self, engine: Engine) -> None: self._event_handlers_timer.reset() - def _as_last_epoch_completed(self, engine: Engine): + def _as_last_epoch_completed(self, engine: Engine) -> None: t = self._event_handlers_timer.value() e = engine.state.epoch - 1 self.event_handlers_times[Events.EPOCH_COMPLETED][e] = t - def _as_first_completed(self, engine: Engine): + def _as_first_completed(self, engine: Engine) -> None: self._event_handlers_timer.reset() - def _as_last_completed(self, engine: Engine): + def _as_last_completed(self, engine: Engine) -> None: self.event_handlers_times[Events.COMPLETED][0] = self._event_handlers_timer.value() # Remove added handlers: @@ 
-203,7 +208,7 @@ def _as_last_completed(self, engine: Engine): for e, m in zip(self._events, self._lmethods): engine.remove_event_handler(m, e) - def attach(self, engine: Engine): + def attach(self, engine: Engine) -> None: if not isinstance(engine, Engine): raise TypeError("Argument engine should be ignite.engine.Engine, " "but given {}".format(type(engine))) @@ -211,10 +216,12 @@ def attach(self, engine: Engine): engine._event_handlers[Events.STARTED].insert(0, (self._as_first_started, (engine,), {})) @staticmethod - def _compute_basic_stats(data: Sequence): + def _compute_basic_stats(data: torch.Tensor) -> Dict[str, Union[str, float, Tuple[Union[float], Union[float]]]]: # compute on non-zero data: data = data[data > 0] - out = [("total", torch.sum(data).item() if len(data) > 0 else "not yet triggered")] + out = [ + ("total", torch.sum(data).item() if len(data) > 0 else "not yet triggered") + ] # type: List[Tuple[str, Union[str, float, Tuple[Union[float], Union[float]]]]] if len(data) > 1: out += [ ("min/index", (torch.min(data).item(), torch.argmin(data).item())), @@ -224,7 +231,7 @@ def _compute_basic_stats(data: Sequence): ] return OrderedDict(out) - def get_results(self): + def get_results(self) -> Dict[str, Dict[str, Any]]: """ Method to fetch the aggregated profiler results after the engine is run @@ -233,23 +240,23 @@ def get_results(self): results = profiler.get_results() """ - total_eh_time = sum([(self.event_handlers_times[e]).sum() for e in Events if e not in self.events_to_ignore]) + total_eh_time = sum( + [(self.event_handlers_times[e]).sum() for e in Events if e not in self.events_to_ignore] + ) # type: Union[int, torch.Tensor] + event_handlers_stats = dict( + [ + (str(e.name).replace(".", "_"), self._compute_basic_stats(self.event_handlers_times[e])) + for e in Events + if e not in self.events_to_ignore + ] + + [("total_time", total_eh_time)] # type: ignore[list-item] + ) return OrderedDict( [ ("processing_stats", self._compute_basic_stats(self.processing_times)), ("dataflow_stats", self._compute_basic_stats(self.dataflow_times)), - ( - "event_handlers_stats", - dict( - [ - (str(e.name).replace(".", "_"), self._compute_basic_stats(self.event_handlers_times[e])) - for e in Events - if e not in self.events_to_ignore - ] - + [("total_time", total_eh_time)] - ), - ), + ("event_handlers_stats", event_handlers_stats,), ( "event_handlers_names", {str(e.name).replace(".", "_") + "_names": v for e, v in self.event_handlers_names.items()}, @@ -257,7 +264,7 @@ def get_results(self): ] ) - def write_results(self, output_path: str): + def write_results(self, output_path: str) -> None: """ Method to store the unaggregated profiling results to a csv file @@ -336,7 +343,7 @@ def write_results(self, output_path: str): results_df.to_csv(output_path, index=False) @staticmethod - def print_results(results: Union[Dict, Iterable]): + def print_results(results: Dict) -> str: """ Method to print the aggregated results from the profiler @@ -382,14 +389,14 @@ def print_results(results: Union[Dict, Iterable]): """ - def to_str(v: Union[str, tuple]): + def to_str(v: Union[str, tuple]) -> str: if isinstance(v, str): return v elif isinstance(v, tuple): return "{:.5f}/{}".format(v[0], v[1]) return "{:.5f}".format(v) - def odict_to_str(d: Mapping): + def odict_to_str(d: Mapping) -> str: out = " | ".join([to_str(v) for v in d.values()]) return out @@ -470,21 +477,21 @@ def log_intermediate_results(): EVENT_FILTER_THESHOLD_TIME = 0.0001 - def __init__(self): + def __init__(self) -> None: 
self._dataflow_timer = Timer() self._processing_timer = Timer() self._event_handlers_timer = Timer() - self.dataflow_times = None - self.processing_times = None - self.event_handlers_times = None + self.dataflow_times = [] # type: List[float] + self.processing_times = [] # type: List[float] + self.event_handlers_times = {} # type: Dict[EventEnum, Dict[str, List[float]]] @staticmethod def _get_callable_name(handler: Callable) -> str: # get name of the callable handler return getattr(handler, "__qualname__", handler.__class__.__name__) - def _create_wrapped_handler(self, handler: Callable, event: Events) -> Callable: + def _create_wrapped_handler(self, handler: Callable, event: EventEnum) -> Callable: @functools.wraps(handler) def _timeit_handler(*args: Any, **kwargs: Any) -> None: self._event_handlers_timer.reset() @@ -583,15 +590,17 @@ def get_results(self) -> List[List[Union[str, float]]]: ) total_eh_time = round(float(total_eh_time), 5,) - def compute_basic_stats(data: Sequence) -> List[Union[str, float]]: - data = torch.as_tensor(data, dtype=torch.float32) + def compute_basic_stats( + times: Union[Sequence, torch.Tensor] + ) -> List[Union[str, float, Tuple[Union[str, float], Union[str, float]]]]: + data = torch.as_tensor(times, dtype=torch.float32) # compute on non-zero data: data = data[data > 0] - total = round(torch.sum(data).item(), 5) if len(data) > 0 else "not triggered" - min_index = ("None", "None") - max_index = ("None", "None") - mean = "None" - std = "None" + total = round(torch.sum(data).item(), 5) if len(data) > 0 else "not triggered" # type: Union[str, float] + min_index = ("None", "None") # type: Tuple[Union[str, float], Union[str, float]] + max_index = ("None", "None") # type: Tuple[Union[str, float], Union[str, float]] + mean = "None" # type: Union[str, float] + std = "None" # type: Union[str, float] if len(data) > 0: min_index = (round(torch.min(data).item(), 5), torch.argmin(data).item()) max_index = (round(torch.max(data).item(), 5), torch.argmax(data).item()) @@ -695,8 +704,8 @@ def print_results(results: List[List[Union[str, float]]]) -> None: """ # adopted implementation of torch.autograd.profiler.build_table - handler_column_width = max([len(item[0]) for item in results]) + 4 - event_column_width = max([len(item[1]) for item in results]) + 4 + handler_column_width = max([len(item[0]) for item in results]) + 4 # type: ignore[arg-type] + event_column_width = max([len(item[1]) for item in results]) + 4 # type: ignore[arg-type] DEFAULT_COLUMN_WIDTH = 14 @@ -742,8 +751,8 @@ def append(s: str) -> None: for row in results[:-3]: # format min/idx and max/idx - row[3] = "{}/{}".format(*row[3]) - row[4] = "{}/{}".format(*row[4]) + row[3] = "{}/{}".format(*row[3]) # type: ignore[misc] + row[4] = "{}/{}".format(*row[4]) # type: ignore[misc] append(row_format.format(*row)) @@ -754,8 +763,8 @@ def append(s: str) -> None: summary_format = "{} took total {}s [min/index: {}, max/index: {}, mean: {}s, std: {}s]" for row in results[-2:]: - row[3] = "{}s/{}".format(*row[3]) - row[4] = "{}s/{}".format(*row[4]) + row[3] = "{}s/{}".format(*row[3]) # type: ignore[misc] + row[4] = "{}s/{}".format(*row[4]) # type: ignore[misc] del row[1] append(summary_format.format(*row)) print("".join(result)) diff --git a/ignite/contrib/handlers/tqdm_logger.py b/ignite/contrib/handlers/tqdm_logger.py index a0fe33f561e6..6dc7d22688c0 100644 --- a/ignite/contrib/handlers/tqdm_logger.py +++ b/ignite/contrib/handlers/tqdm_logger.py @@ -1,13 +1,12 @@ # -*- coding: utf-8 -*- import warnings -from enum import 
Enum -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, List, Optional, Union import torch from ignite.contrib.handlers.base_logger import BaseLogger, BaseOutputHandler from ignite.engine import Engine, Events -from ignite.engine.events import CallableEventWithFilter, EventEnum +from ignite.engine.events import CallableEventWithFilter class ProgressBar(BaseLogger): @@ -99,14 +98,14 @@ class ProgressBar(BaseLogger): Events.ITERATION_COMPLETED, Events.EPOCH_COMPLETED, Events.COMPLETED, - ] + ] # type: List[Union[Events, CallableEventWithFilter]] def __init__( self, persist: bool = False, bar_format: str = "{desc}[{n_fmt}/{total_fmt}] {percentage:3.0f}%|{bar}{postfix} [{elapsed}<{remaining}]", **tqdm_kwargs: Any - ): + ) -> None: try: from tqdm.autonotebook import tqdm @@ -122,12 +121,12 @@ def __init__( self.bar_format = bar_format self.tqdm_kwargs = tqdm_kwargs - def _reset(self, pbar_total: int): + def _reset(self, pbar_total: Optional[int]) -> None: self.pbar = self.pbar_cls( total=pbar_total, leave=self.persist, bar_format=self.bar_format, initial=1, **self.tqdm_kwargs ) - def _close(self, engine: Engine): + def _close(self, engine: Engine) -> None: if self.pbar is not None: # https://github.com/tqdm/notebook.py#L240-L250 # issue #1115 : notebook backend of tqdm checks if n < total (error or KeyboardInterrupt) @@ -138,12 +137,14 @@ def _close(self, engine: Engine): self.pbar = None @staticmethod - def _compare_lt(event1: EventEnum, event2: EventEnum): + def _compare_lt( + event1: Union[Events, CallableEventWithFilter], event2: Union[Events, CallableEventWithFilter] + ) -> bool: i1 = ProgressBar._events_order.index(event1) i2 = ProgressBar._events_order.index(event2) return i1 < i2 - def log_message(self, message: str): + def log_message(self, message: str) -> None: """ Logs a message, preserving the progress bar correct output format. @@ -154,14 +155,14 @@ def log_message(self, message: str): tqdm.write(message, file=self.tqdm_kwargs.get("file", None)) - def attach( + def attach( # type: ignore[override] self, engine: Engine, metric_names: Optional[str] = None, output_transform: Optional[Callable] = None, - event_name: Union[CallableEventWithFilter, Events] = Events.ITERATION_COMPLETED, - closing_event_name: Events = Events.EPOCH_COMPLETED, - ): + event_name: Union[Events, CallableEventWithFilter] = Events.ITERATION_COMPLETED, + closing_event_name: Union[Events, CallableEventWithFilter] = Events.EPOCH_COMPLETED, + ) -> None: """ Attaches the progress bar to an engine object. 
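# --- Editor's note: illustrative usage sketch, not part of the patch ---
# Typical use of the ProgressBar API annotated above, assuming a `trainer` Engine.
from ignite.contrib.handlers import ProgressBar

pbar = ProgressBar(persist=True)
# show the running loss taken from the engine output
pbar.attach(trainer, output_transform=lambda output: {"loss": output})
# messages are routed through tqdm so the bar rendering is not broken
pbar.log_message("Training started")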
@@ -200,14 +201,16 @@ def attach( super(ProgressBar, self).attach(engine, log_handler, event_name) engine.add_event_handler(closing_event_name, self._close) - def attach_opt_params_handler(self, engine: Engine, event_name: Union[str, EventEnum], *args: Any, **kwargs: Any): + def attach_opt_params_handler( + self, engine: Engine, event_name: Union[str, Events], *args: Any, **kwargs: Any + ) -> None: """Intentionally empty""" pass - def _create_output_handler(self, *args: Any, **kwargs: Any): + def _create_output_handler(self, *args: Any, **kwargs: Any) -> "_OutputHandler": return _OutputHandler(*args, **kwargs) - def _create_opt_params_handler(self, *args: Any, **kwargs: Any): + def _create_opt_params_handler(self, *args: Any, **kwargs: Any) -> Callable: """Intentionally empty""" pass @@ -232,10 +235,10 @@ class _OutputHandler(BaseOutputHandler): def __init__( self, description: str, - metric_names: Optional[str] = None, + metric_names: Optional[Union[str, List[str]]] = None, output_transform: Optional[Callable] = None, - closing_event_name: EventEnum = Events.EPOCH_COMPLETED, - ): + closing_event_name: Union[Events, CallableEventWithFilter] = Events.EPOCH_COMPLETED, + ) -> None: if metric_names is None and output_transform is None: # This helps to avoid 'Either metric_names or output_transform should be defined' of BaseOutputHandler metric_names = [] @@ -243,14 +246,14 @@ def __init__( self.closing_event_name = closing_event_name @staticmethod - def get_max_number_events(event_name: Union[CallableEventWithFilter, Enum], engine: Engine): + def get_max_number_events(event_name: Union[str, Events, CallableEventWithFilter], engine: Engine) -> Optional[int]: if event_name in (Events.ITERATION_STARTED, Events.ITERATION_COMPLETED): return engine.state.epoch_length if event_name in (Events.EPOCH_STARTED, Events.EPOCH_COMPLETED): return engine.state.max_epochs return 1 - def __call__(self, engine: Engine, logger: ProgressBar, event_name: Union[CallableEventWithFilter, Enum]): + def __call__(self, engine: Engine, logger: ProgressBar, event_name: Union[str, Events]) -> None: pbar_total = self.get_max_number_events(event_name, engine) if logger.pbar is None: @@ -261,10 +264,10 @@ def __call__(self, engine: Engine, logger: ProgressBar, event_name: Union[Callab desc = self.tag or default_desc max_num_of_closing_events = self.get_max_number_events(self.closing_event_name, engine) - if max_num_of_closing_events > 1: + if max_num_of_closing_events and max_num_of_closing_events > 1: global_step = engine.state.get_event_attrib_value(self.closing_event_name) desc += " [{}/{}]".format(global_step, max_num_of_closing_events) - logger.pbar.set_description(desc) + logger.pbar.set_description(desc) # type: ignore[attr-defined] metrics = self._setup_output_metrics(engine) @@ -283,9 +286,9 @@ def __call__(self, engine: Engine, logger: ProgressBar, event_name: Union[Callab rendered_metrics[key] = value if rendered_metrics: - logger.pbar.set_postfix(**rendered_metrics) + logger.pbar.set_postfix(**rendered_metrics) # type: ignore[attr-defined] global_step = engine.state.get_event_attrib_value(event_name) if pbar_total is not None: global_step = (global_step - 1) % pbar_total + 1 - logger.pbar.update(global_step - logger.pbar.n) + logger.pbar.update(global_step - logger.pbar.n) # type: ignore[attr-defined] diff --git a/ignite/contrib/handlers/trains_logger.py b/ignite/contrib/handlers/trains_logger.py index 49bc510542f9..7c00da41c87d 100644 --- a/ignite/contrib/handlers/trains_logger.py +++ 
b/ignite/contrib/handlers/trains_logger.py @@ -5,7 +5,7 @@ from collections import defaultdict from datetime import datetime from enum import Enum -from typing import Any, Callable, List, Mapping, Optional, Type +from typing import Any, Callable, DefaultDict, List, Mapping, Optional, Tuple, Type, Union import torch from torch.nn import Module @@ -19,7 +19,7 @@ BaseWeightsHistHandler, BaseWeightsScalarHandler, ) -from ignite.engine import Engine, EventEnum +from ignite.engine import Engine, Events from ignite.handlers import global_step_from_engine from ignite.handlers.checkpoint import DiskSaver @@ -119,7 +119,7 @@ class TrainsLogger(BaseLogger): """ - def __init__(self, *_, **kwargs: Any): + def __init__(self, *_: Any, **kwargs: Any) -> None: try: from trains import Task from trains.binding.frameworks.tensorflow_bind import WeightsGradientHistHelper @@ -135,15 +135,15 @@ def __init__(self, *_, **kwargs: Any): warnings.warn("TrainsSaver: running in bypass mode") class _Stub(object): - def __call__(self, *_, **__): + def __call__(self, *_: Any, **__: Any) -> "_Stub": return self - def __getattr__(self, attr): + def __getattr__(self, attr: str) -> "_Stub": if attr in ("name", "id"): - return "" + return "" # type: ignore[return-value] return self - def __setattr__(self, attr, val): + def __setattr__(self, attr: str, val: Any) -> None: pass self._task = _Stub() @@ -183,13 +183,13 @@ def bypass_mode(cls) -> bool: """ return getattr(cls, "_bypass", bool(os.environ.get("CI"))) - def close(self): + def close(self) -> None: self.trains_logger.flush() - def _create_output_handler(self, *args: Any, **kwargs: Any): + def _create_output_handler(self, *args: Any, **kwargs: Any) -> "OutputHandler": return OutputHandler(*args, **kwargs) - def _create_opt_params_handler(self, *args: Any, **kwargs: Any): + def _create_opt_params_handler(self, *args: Any, **kwargs: Any) -> "OptimizerParamsHandler": return OptimizerParamsHandler(*args, **kwargs) @@ -293,17 +293,17 @@ def __init__( metric_names: Optional[List[str]] = None, output_transform: Optional[Callable] = None, global_step_transform: Optional[Callable] = None, - ): + ) -> None: super(OutputHandler, self).__init__(tag, metric_names, output_transform, global_step_transform) - def __call__(self, engine: Engine, logger: TrainsLogger, event_name: EventEnum): + def __call__(self, engine: Engine, logger: TrainsLogger, event_name: Union[str, Events]) -> None: if not isinstance(logger, TrainsLogger): raise RuntimeError("Handler OutputHandler works only with TrainsLogger") metrics = self._setup_output_metrics(engine) - global_step = self.global_step_transform(engine, event_name) + global_step = self.global_step_transform(engine, event_name) # type: ignore[misc] if not isinstance(global_step, int): raise TypeError( @@ -359,10 +359,10 @@ class OptimizerParamsHandler(BaseOptimizerParamsHandler): tag (str, optional): common title for all produced plots. 
For example, "generator" """ - def __init__(self, optimizer: Optimizer, param_name: str = "lr", tag: Optional[str] = None): + def __init__(self, optimizer: Optimizer, param_name: str = "lr", tag: Optional[str] = None) -> None: super(OptimizerParamsHandler, self).__init__(optimizer, param_name, tag) - def __call__(self, engine: Engine, logger: TrainsLogger, event_name: EventEnum): + def __call__(self, engine: Engine, logger: TrainsLogger, event_name: Union[str, Events]) -> None: if not isinstance(logger, TrainsLogger): raise RuntimeError("Handler OptimizerParamsHandler works only with TrainsLogger") @@ -410,10 +410,10 @@ class WeightsScalarHandler(BaseWeightsScalarHandler): """ - def __init__(self, model: Module, reduction: Callable = torch.norm, tag: Optional[str] = None): + def __init__(self, model: Module, reduction: Callable = torch.norm, tag: Optional[str] = None) -> None: super(WeightsScalarHandler, self).__init__(model, reduction, tag=tag) - def __call__(self, engine: Engine, logger: TrainsLogger, event_name: EventEnum): + def __call__(self, engine: Engine, logger: TrainsLogger, event_name: Union[str, Events]) -> None: if not isinstance(logger, TrainsLogger): raise RuntimeError("Handler WeightsScalarHandler works only with TrainsLogger") @@ -462,10 +462,10 @@ class WeightsHistHandler(BaseWeightsHistHandler): """ - def __init__(self, model: Module, tag: Optional[str] = None): + def __init__(self, model: Module, tag: Optional[str] = None) -> None: super(WeightsHistHandler, self).__init__(model, tag=tag) - def __call__(self, engine: Engine, logger: TrainsLogger, event_name: EventEnum): + def __call__(self, engine: Engine, logger: TrainsLogger, event_name: Union[str, Events]) -> None: if not isinstance(logger, TrainsLogger): raise RuntimeError("Handler 'WeightsHistHandler' works only with TrainsLogger") @@ -517,10 +517,10 @@ class GradsScalarHandler(BaseWeightsScalarHandler): """ - def __init__(self, model: Module, reduction: Callable = torch.norm, tag: Optional[str] = None): + def __init__(self, model: Module, reduction: Callable = torch.norm, tag: Optional[str] = None) -> None: super(GradsScalarHandler, self).__init__(model, reduction, tag=tag) - def __call__(self, engine: Engine, logger: TrainsLogger, event_name: EventEnum): + def __call__(self, engine: Engine, logger: TrainsLogger, event_name: Union[str, Events]) -> None: if not isinstance(logger, TrainsLogger): raise RuntimeError("Handler GradsScalarHandler works only with TrainsLogger") @@ -568,10 +568,10 @@ class GradsHistHandler(BaseWeightsHistHandler): """ - def __init__(self, model: Module, tag: Optional[str] = None): + def __init__(self, model: Module, tag: Optional[str] = None) -> None: super(GradsHistHandler, self).__init__(model, tag=tag) - def __call__(self, engine: Engine, logger: TrainsLogger, event_name: EventEnum): + def __call__(self, engine: Engine, logger: TrainsLogger, event_name: Union[str, Events]) -> None: if not isinstance(logger, TrainsLogger): raise RuntimeError("Handler 'GradsHistHandler' works only with TrainsLogger") @@ -633,8 +633,13 @@ class TrainsSaver(DiskSaver): """ def __init__( - self, logger: TrainsLogger = None, output_uri: str = None, dirname: str = None, *args: Any, **kwargs: Any - ): + self, + logger: Optional[TrainsLogger] = None, + output_uri: Optional[str] = None, + dirname: Optional[str] = None, + *args: Any, + **kwargs: Any + ) -> None: self._setup_check_trains(logger, output_uri) @@ -645,7 +650,7 @@ def __init__( 
prefix="ignite_checkpoints_{}".format(datetime.now().strftime("%Y_%m_%d_%H_%M_%S_")) ) if idist.get_world_size() > 1: - dirname = idist.all_gather(dirname)[0] + dirname = idist.all_gather(dirname)[0] # type: ignore[index, assignment] warnings.warn("TrainsSaver created a temporary checkpoints directory: {}".format(dirname)) idist.barrier() @@ -654,12 +659,12 @@ def __init__( if "atomic" not in kwargs: kwargs["atomic"] = False - self._checkpoint_slots = defaultdict(list) + self._checkpoint_slots = defaultdict(list) # type: DefaultDict[Union[str, Tuple[str, str]], List[Any]] - super(TrainsSaver, self).__init__(dirname=dirname, *args, **kwargs) + super(TrainsSaver, self).__init__(dirname=dirname, *args, **kwargs) # type: ignore[misc] @idist.one_rank_only() - def _setup_check_trains(self, logger: TrainsLogger, output_uri: str): + def _setup_check_trains(self, logger: TrainsLogger, output_uri: str) -> None: try: from trains import Task except ImportError: @@ -690,7 +695,7 @@ def __init__( filename: str, basename: str, metadata: Optional[Mapping] = None, - ): + ) -> None: self._callback_type = callback_type self._slots = slots self._checkpoint_key = str(checkpoint_key) @@ -698,8 +703,8 @@ def __init__( self._basename = basename self._metadata = metadata - def pre_callback(self, action: str, model_info: Any): - if action != self._callback_type.save: + def pre_callback(self, action: str, model_info: Any) -> Any: + if action != self._callback_type.save: # type: ignore[attr-defined] return model_info try: @@ -713,8 +718,8 @@ def pre_callback(self, action: str, model_info: Any): model_info.local_model_id = "{}:{}".format(self._checkpoint_key, model_info.upload_filename) return model_info - def post_callback(self, action: str, model_info: Any): - if action != self._callback_type.save: + def post_callback(self, action: str, model_info: Any) -> Any: + if action != self._callback_type.save: # type: ignore[attr-defined] return model_info model_info.model.name = "{}: {}".format(model_info.task.name, self._filename) @@ -742,7 +747,7 @@ def __call__(self, checkpoint: Mapping, filename: str, metadata: Optional[Mappin ) try: - basename = metadata["basename"] + basename = metadata["basename"] # type: ignore[index] except (TypeError, KeyError): warnings.warn("Checkpoint metadata missing or basename cannot be found") basename = "checkpoint" @@ -786,6 +791,8 @@ def get_local_copy(self, filename: str) -> Optional[str]: return artifact.get_local_copy() self._task.get_logger().report_text("Can not find artifact {}".format(filename)) + return None + @idist.one_rank_only() def remove(self, filename: str) -> None: super(TrainsSaver, self).remove(filename) diff --git a/ignite/contrib/handlers/visdom_logger.py b/ignite/contrib/handlers/visdom_logger.py index 010ab3b05493..7b737c6bea85 100644 --- a/ignite/contrib/handlers/visdom_logger.py +++ b/ignite/contrib/handlers/visdom_logger.py @@ -1,7 +1,7 @@ import numbers import os import warnings -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union, cast import torch import torch.nn as nn @@ -13,7 +13,7 @@ BaseOutputHandler, BaseWeightsScalarHandler, ) -from ignite.engine import Engine +from ignite.engine import Engine, Events from ignite.handlers import global_step_from_engine __all__ = [ @@ -170,7 +170,7 @@ def __init__( ) if server is None: - server = os.environ.get("VISDOM_SERVER_URL", "localhost") + server = cast(str, os.environ.get("VISDOM_SERVER_URL", "localhost")) if port is None: port = 
@@ -185,37 +185,39 @@ def __init__(

         self.vis = visdom.Visdom(server=server, port=port, raise_exceptions=raise_exceptions, **kwargs)

-        if not self.vis.offline and not self.vis.check_connection():
+        if not self.vis.offline and not self.vis.check_connection():  # type: ignore[attr-defined]
             raise RuntimeError(
                 "Failed to connect to Visdom server at {}. Did you run python -m visdom.server ?".format(server)
             )

-        self.executor = _DummyExecutor()
+        self.executor = _DummyExecutor()  # type: Union[_DummyExecutor, "ThreadPoolExecutor"]
         if num_workers > 0:
             from concurrent.futures import ThreadPoolExecutor

             self.executor = ThreadPoolExecutor(max_workers=num_workers)

-    def _save(self):
-        self.vis.save([self.vis.env])
+    def _save(self) -> None:
+        self.vis.save([self.vis.env])  # type: ignore[attr-defined]

-    def close(self):
+    def close(self) -> None:
         self.executor.shutdown()
-        self.vis = None
+        self.vis.close()

-    def _create_output_handler(self, *args: Any, **kwargs: Any):
+    def _create_output_handler(self, *args: Any, **kwargs: Any) -> "OutputHandler":
         return OutputHandler(*args, **kwargs)

-    def _create_opt_params_handler(self, *args: Any, **kwargs: Any):
+    def _create_opt_params_handler(self, *args: Any, **kwargs: Any) -> "OptimizerParamsHandler":
         return OptimizerParamsHandler(*args, **kwargs)


 class _BaseVisDrawer:
-    def __init__(self, show_legend: bool = False):
-        self.windows = {}
+    def __init__(self, show_legend: bool = False) -> None:
+        self.windows = {}  # type: Dict[str, Any]
         self.show_legend = show_legend

-    def add_scalar(self, logger: VisdomLogger, k: str, v: Union[str, float], event_name: Any, global_step: int):
+    def add_scalar(
+        self, logger: VisdomLogger, k: str, v: Union[str, float, torch.Tensor], event_name: Any, global_step: int
+    ) -> None:
         """
         Helper method to log a scalar with VisdomLogger.
@@ -240,7 +242,7 @@ def add_scalar(logger: VisdomLogger, k: str, v: Union[str, float], event_n
         kwargs = {
             "X": [global_step],
             "Y": [v],
-            "env": logger.vis.env,
+            "env": logger.vis.env,  # type: ignore[attr-defined]
             "win": self.windows[k]["win"],
             "update": update,
             "opts": self.windows[k]["opts"],
@@ -346,18 +348,18 @@ def __init__(
         output_transform: Optional[Callable] = None,
         global_step_transform: Optional[Callable] = None,
         show_legend: bool = False,
-    ):
+    ) -> None:
         super(OutputHandler, self).__init__(tag, metric_names, output_transform, global_step_transform)
         _BaseVisDrawer.__init__(self, show_legend=show_legend)

-    def __call__(self, engine: Engine, logger: VisdomLogger, event_name: Any):
+    def __call__(self, engine: Engine, logger: VisdomLogger, event_name: Union[str, Events]) -> None:

         if not isinstance(logger, VisdomLogger):
             raise RuntimeError("Handler 'OutputHandler' works only with VisdomLogger")

         metrics = self._setup_output_metrics(engine)

-        global_step = self.global_step_transform(engine, event_name)
+        global_step = self.global_step_transform(engine, event_name)  # type: ignore[misc]

         if not isinstance(global_step, int):
             raise TypeError(
@@ -367,13 +369,13 @@ def __call__(self, engine: Engine, logger: VisdomLogger, event_name: Any):

         for key, value in metrics.items():

-            values = []
+            values = []  # type: List[Union[float, torch.Tensor]]
             keys = []
             if isinstance(value, numbers.Number) or isinstance(value, torch.Tensor) and value.ndimension() == 0:
-                values.append(value)
+                values.append(value)  # type: ignore[arg-type]
                 keys.append(key)
             elif isinstance(value, torch.Tensor) and value.ndimension() == 1:
-                values = value
+                values = value  # type: ignore[assignment]
                 keys = ["{}/{}".format(key, i) for i in range(len(value))]
             else:
                 warnings.warn("VisdomLogger output_handler can not log " "metrics value type {}".format(type(value)))
@@ -420,11 +422,11 @@ class OptimizerParamsHandler(BaseOptimizerParamsHandler, _BaseVisDrawer):

     def __init__(
         self, optimizer: Optimizer, param_name: str = "lr", tag: Optional[str] = None, show_legend: bool = False,
-    ):
+    ) -> None:
         super(OptimizerParamsHandler, self).__init__(optimizer, param_name, tag)
         _BaseVisDrawer.__init__(self, show_legend=show_legend)

-    def __call__(self, engine: Engine, logger: VisdomLogger, event_name: Any):
+    def __call__(self, engine: Engine, logger: VisdomLogger, event_name: Union[str, Events]) -> None:
         if not isinstance(logger, VisdomLogger):
             raise RuntimeError("Handler OptimizerParamsHandler works only with VisdomLogger")
@@ -471,11 +473,11 @@ class WeightsScalarHandler(BaseWeightsScalarHandler, _BaseVisDrawer):

     def __init__(
         self, model: nn.Module, reduction: Callable = torch.norm, tag: Optional[str] = None, show_legend: bool = False,
-    ):
+    ) -> None:
         super(WeightsScalarHandler, self).__init__(model, reduction, tag=tag)
         _BaseVisDrawer.__init__(self, show_legend=show_legend)

-    def __call__(self, engine: Engine, logger: VisdomLogger, event_name: Any):
+    def __call__(self, engine: Engine, logger: VisdomLogger, event_name: Union[str, Events]) -> None:
         if not isinstance(logger, VisdomLogger):
             raise RuntimeError("Handler 'WeightsScalarHandler' works only with VisdomLogger")
@@ -522,11 +524,11 @@ class GradsScalarHandler(BaseWeightsScalarHandler, _BaseVisDrawer):

     def __init__(
         self, model: nn.Module, reduction: Callable = torch.norm, tag: Optional[str] = None, show_legend: bool = False,
-    ):
+    ) -> None:
         super(GradsScalarHandler, self).__init__(model, reduction, tag)
         _BaseVisDrawer.__init__(self, show_legend=show_legend)

-    def __call__(self, engine: Engine, logger: VisdomLogger, event_name: Any):
+    def __call__(self, engine: Engine, logger: VisdomLogger, event_name: Union[str, Events]) -> None:
         if not isinstance(logger, VisdomLogger):
             raise RuntimeError("Handler 'GradsScalarHandler' works only with VisdomLogger")
@@ -543,17 +545,17 @@ def __call__(self, engine: Engine, logger: VisdomLogger, event_name: Any):

 class _DummyExecutor:
     class _DummyFuture:
-        def __init__(self, result: Any):
+        def __init__(self, result: Any) -> None:
             self._output = result

-        def result(self):
+        def result(self) -> Any:
             return self._output

-    def __init__(self, *args: Any, **kwargs: Any):
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
         pass

-    def submit(self, fn: Callable, **kwargs: Any):
+    def submit(self, fn: Callable, **kwargs: Any) -> "_DummyFuture":
         return _DummyExecutor._DummyFuture(fn(**kwargs))

-    def shutdown(self, *args: Any, **kwargs: Any):
+    def shutdown(self, *args: Any, **kwargs: Any) -> None:
         pass
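For orientation, a rough sketch of how the Visdom handlers above are typically wired up (not part of the diff; it assumes trainer/evaluator Engines, a model, an optimizer, a running python -m visdom.server instance, and placeholder metric names):

    from ignite.contrib.handlers.visdom_logger import (
        GradsScalarHandler,
        OptimizerParamsHandler,
        OutputHandler,
        VisdomLogger,
    )
    from ignite.engine import Events
    from ignite.handlers import global_step_from_engine

    vd_logger = VisdomLogger(server="localhost", port=8097, num_workers=1)

    # Plot validation metrics against the trainer's global step.
    vd_logger.attach(
        evaluator,
        log_handler=OutputHandler(
            tag="validation",
            metric_names=["accuracy", "nll"],
            global_step_transform=global_step_from_engine(trainer),
        ),
        event_name=Events.EPOCH_COMPLETED,
    )
    vd_logger.attach(trainer, log_handler=OptimizerParamsHandler(optimizer), event_name=Events.ITERATION_STARTED)
    vd_logger.attach(trainer, log_handler=GradsScalarHandler(model), event_name=Events.ITERATION_COMPLETED)

    # close() shuts down the executor and calls vis.close() on the underlying client.
    vd_logger.close()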
diff --git a/ignite/contrib/handlers/wandb_logger.py b/ignite/contrib/handlers/wandb_logger.py
index de1ac2e1a4a0..a101f4d1f56c 100644
--- a/ignite/contrib/handlers/wandb_logger.py
+++ b/ignite/contrib/handlers/wandb_logger.py
@@ -1,10 +1,9 @@
-from enum import Enum
 from typing import Any, Callable, List, Optional, Union

 from torch.optim import Optimizer

 from ignite.contrib.handlers.base_logger import BaseLogger, BaseOptimizerParamsHandler, BaseOutputHandler
-from ignite.engine import CallableEventWithFilter, Engine
+from ignite.engine import Engine, Events
 from ignite.handlers import global_step_from_engine

 __all__ = ["WandBLogger", "OutputHandler", "OptimizerParamsHandler", "global_step_from_engine"]
@@ -117,7 +116,7 @@ def score_function(engine):
         evaluator.add_event_handler(Events.COMPLETED, model_checkpoint, {'model': model})
     """

-    def __init__(self, *args: Any, **kwargs: Any):
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
         try:
             import wandb
@@ -130,13 +129,13 @@ def __init__(self, *args: Any, **kwargs: Any):
         if kwargs.get("init", True):
             wandb.init(*args, **kwargs)

-    def __getattr__(self, attr: Any):
-        return getattr(self._wandb, attr)
+    def __getattr__(self, attr: Any) -> Any:
+        return getattr(self._wandb, attr)  # type: ignore[misc]

-    def _create_output_handler(self, *args: Any, **kwargs: Any):
+    def _create_output_handler(self, *args: Any, **kwargs: Any) -> "OutputHandler":
         return OutputHandler(*args, **kwargs)

-    def _create_opt_params_handler(self, *args: Any, **kwargs: Any):
+    def _create_opt_params_handler(self, *args: Any, **kwargs: Any) -> "OptimizerParamsHandler":
         return OptimizerParamsHandler(*args, **kwargs)
@@ -252,16 +251,16 @@ def __init__(
         output_transform: Optional[Callable] = None,
         global_step_transform: Optional[Callable] = None,
         sync: Optional[bool] = None,
-    ):
+    ) -> None:
         super().__init__(tag, metric_names, output_transform, global_step_transform)
         self.sync = sync

-    def __call__(self, engine: Engine, logger: WandBLogger, event_name: Union[CallableEventWithFilter, Enum]):
+    def __call__(self, engine: Engine, logger: WandBLogger, event_name: Union[str, Events]) -> None:
         if not isinstance(logger, WandBLogger):
             raise RuntimeError("Handler '{}' works only with WandBLogger.".format(self.__class__.__name__))

-        global_step = self.global_step_transform(engine, event_name)
+        global_step = self.global_step_transform(engine, event_name)  # type: ignore[misc]
         if not isinstance(global_step, int):
             raise TypeError(
                 "global_step must be int, got {}."
@@ -319,11 +318,11 @@ class OptimizerParamsHandler(BaseOptimizerParamsHandler):

     def __init__(
         self, optimizer: Optimizer, param_name: str = "lr", tag: Optional[str] = None, sync: Optional[bool] = None,
-    ):
+    ) -> None:
         super(OptimizerParamsHandler, self).__init__(optimizer, param_name, tag)
         self.sync = sync

-    def __call__(self, engine: Engine, logger: WandBLogger, event_name: Union[CallableEventWithFilter, Enum]):
+    def __call__(self, engine: Engine, logger: WandBLogger, event_name: Union[str, Events]) -> None:
         if not isinstance(logger, WandBLogger):
             raise RuntimeError("Handler OptimizerParamsHandler works only with WandBLogger")
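Similarly, a minimal WandBLogger sketch (not from the diff; it assumes trainer/evaluator Engines and an optimizer, and relies on __init__ above forwarding keyword arguments to wandb.init, so the project/config values are placeholders):

    from ignite.contrib.handlers.wandb_logger import OptimizerParamsHandler, OutputHandler, WandBLogger
    from ignite.engine import Events
    from ignite.handlers import global_step_from_engine

    wb_logger = WandBLogger(project="ignite-example", config={"lr": 1e-3})

    wb_logger.attach(
        evaluator,
        log_handler=OutputHandler(
            tag="validation",
            metric_names=["accuracy"],
            global_step_transform=global_step_from_engine(trainer),
        ),
        event_name=Events.EPOCH_COMPLETED,
    )
    wb_logger.attach(trainer, log_handler=OptimizerParamsHandler(optimizer), event_name=Events.ITERATION_STARTED)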
diff --git a/ignite/engine/events.py b/ignite/engine/events.py
index cd8f14da7d62..e05f477b2a82 100644
--- a/ignite/engine/events.py
+++ b/ignite/engine/events.py
@@ -362,7 +362,7 @@ class State:
         Events.EPOCH_COMPLETED: "epoch",
         Events.STARTED: "epoch",
         Events.COMPLETED: "epoch",
-    }
+    }  # type: Dict[Union[str, "Events", "CallableEventWithFilter"], str]

     def __init__(self, **kwargs: Any) -> None:
         self.iteration = 0
@@ -390,7 +390,7 @@ def _update_attrs(self) -> None:
             if not hasattr(self, value):
                 setattr(self, value, 0)

-    def get_event_attrib_value(self, event_name: Events) -> int:
+    def get_event_attrib_value(self, event_name: Union[str, Events, CallableEventWithFilter]) -> int:
         if event_name not in State.event_to_attr:
             raise RuntimeError("Unknown event name '{}'".format(event_name))
         return getattr(self, State.event_to_attr[event_name])
diff --git a/mypy.ini b/mypy.ini
index 62b36053b343..24487a2a59e3 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -24,21 +24,44 @@
 warn_unreachable = False
 warn_unused_configs = True
 warn_unused_ignores = True

-[mypy-ignite.contrib.handlers.*]
-
+[mypy-ignite.contrib.handlers.param_scheduler]
 ignore_errors = True

 [mypy-horovod.*]
 ignore_missing_imports = True

+[mypy-matplotlib.*]
+ignore_missing_imports = True
+
+[mypy-mlflow.*]
+ignore_missing_imports = True
+
+[mypy-neptune.*]
+ignore_missing_imports = True
+
 [mypy-numpy.*]
 ignore_missing_imports = True

+[mypy-pandas.*]
+ignore_missing_imports = True
+
 [mypy-sklearn.*]
 ignore_missing_imports = True

+[mypy-polyaxon_client.*]
+ignore_missing_imports = True
+
 [mypy-pynvml.*]
 ignore_missing_imports = True

+[mypy-tensorboardX.*]
+ignore_missing_imports = True
+
 [mypy-torch_xla.*]
 ignore_missing_imports = True
+
+[mypy-trains.*]
+ignore_missing_imports = True
+
+[mypy-tqdm.*]
+ignore_missing_imports = True
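To illustrate the ignite/engine/events.py change above, a small self-contained sketch of what State.get_event_attrib_value returns for an event it now accepts (only a local ignite install is assumed):

    from ignite.engine import Engine, Events

    trainer = Engine(lambda engine, batch: None)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_progress(engine):
        # Events.EPOCH_COMPLETED maps to the "epoch" attribute via State.event_to_attr,
        # so this prints the same value as engine.state.epoch.
        print(engine.state.get_event_attrib_value(Events.EPOCH_COMPLETED))

    trainer.run(range(4), max_epochs=2)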