From 2d2797b695b9c9af3e9d9e17e69bbca5ac65c8ca Mon Sep 17 00:00:00 2001 From: gruebel Date: Wed, 2 Dec 2020 22:54:33 +0100 Subject: [PATCH 1/7] Activate mypy for contrib.handlers --- ignite/contrib/handlers/base_logger.py | 31 +-- ignite/contrib/handlers/lr_finder.py | 48 ++--- ignite/contrib/handlers/mlflow_logger.py | 28 +-- ignite/contrib/handlers/neptune_logger.py | 32 +-- ignite/contrib/handlers/param_scheduler.py | 202 ++++++++++-------- ignite/contrib/handlers/polyaxon_logger.py | 24 +-- ignite/contrib/handlers/stores.py | 12 +- ignite/contrib/handlers/tensorboard_logger.py | 36 ++-- ignite/contrib/handlers/time_profilers.py | 130 +++++------ ignite/contrib/handlers/tqdm_logger.py | 49 ++--- ignite/contrib/handlers/trains_logger.py | 77 ++++--- ignite/contrib/handlers/visdom_logger.py | 66 +++--- ignite/contrib/handlers/wandb_logger.py | 23 +- ignite/engine/events.py | 4 +- mypy.ini | 26 ++- 15 files changed, 426 insertions(+), 362 deletions(-) diff --git a/ignite/contrib/handlers/base_logger.py b/ignite/contrib/handlers/base_logger.py index 1d983ee59415..92a6609360ed 100644 --- a/ignite/contrib/handlers/base_logger.py +++ b/ignite/contrib/handlers/base_logger.py @@ -1,18 +1,19 @@ import numbers import warnings from abc import ABCMeta, abstractmethod -from typing import Any, Callable, List, Optional, Sequence, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, Union import torch import torch.nn as nn from torch.optim import Optimizer -from ignite.engine import Engine, State +from ignite.engine import Engine, Events, State +from ignite.engine.events import RemovableEventHandle class BaseHandler(metaclass=ABCMeta): @abstractmethod - def __call__(self, engine, logger, event_name): + def __call__(self, engine: Engine, logger: Any, event_name: Union[str, Events]) -> None: pass @@ -68,7 +69,7 @@ def __init__( if global_step_transform is None: - def global_step_transform(engine, event_name): + def global_step_transform(engine: Engine, event_name: Union[str, Events]) -> int: return engine.state.get_event_attrib_value(event_name) self.tag = tag @@ -76,7 +77,7 @@ def global_step_transform(engine, event_name): self.output_transform = output_transform self.global_step_transform = global_step_transform - def _setup_output_metrics(self, engine: Engine): + def _setup_output_metrics(self, engine: Engine) -> Dict[str, Any]: """Helper method to setup metrics to log """ metrics = {} @@ -108,14 +109,14 @@ class BaseWeightsScalarHandler(BaseHandler): Helper handler to log model's weights as scalars. """ - def __init__(self, model: nn.Module, reduction: Callable = torch.norm, tag: Optional[str] = None): + def __init__(self, model: nn.Module, reduction: Callable = torch.norm, tag: Optional[str] = None) -> None: if not isinstance(model, torch.nn.Module): raise TypeError("Argument model should be of type torch.nn.Module, " "but given {}".format(type(model))) if not callable(reduction): raise TypeError("Argument reduction should be callable, " "but given {}".format(type(reduction))) - def _is_0D_tensor(t: torch.Tensor): + def _is_0D_tensor(t: torch.Tensor) -> bool: return isinstance(t, torch.Tensor) and t.ndimension() == 0 # Test reduction function on a tensor @@ -147,7 +148,7 @@ class BaseLogger(metaclass=ABCMeta): """ - def attach(self, engine: Engine, log_handler: Callable, event_name: Any): + def attach(self, engine: Engine, log_handler: Callable, event_name: Union[str, Events]) -> RemovableEventHandle: """Attach the logger to the engine and execute `log_handler` function at `event_name` events. Args: @@ -167,7 +168,7 @@ def attach(self, engine: Engine, log_handler: Callable, event_name: Any): return engine.add_event_handler(event_name, log_handler, self, name) - def attach_output_handler(self, engine: Engine, event_name: Any, *args: Any, **kwargs: Any): + def attach_output_handler(self, engine: Engine, event_name: Any, *args: Any, **kwargs: Any) -> RemovableEventHandle: """Shortcut method to attach `OutputHandler` to the logger. Args: @@ -183,7 +184,7 @@ def attach_output_handler(self, engine: Engine, event_name: Any, *args: Any, **k """ return self.attach(engine, self._create_output_handler(*args, **kwargs), event_name=event_name) - def attach_opt_params_handler(self, engine: Engine, event_name: Any, *args: Any, **kwargs: Any): + def attach_opt_params_handler(self, engine: Engine, event_name: Any, *args: Any, **kwargs: Any) -> None: """Shortcut method to attach `OptimizerParamsHandler` to the logger. Args: @@ -200,18 +201,18 @@ def attach_opt_params_handler(self, engine: Engine, event_name: Any, *args: Any, self.attach(engine, self._create_opt_params_handler(*args, **kwargs), event_name=event_name) @abstractmethod - def _create_output_handler(self, engine: Engine, *args: Any, **kwargs: Any): + def _create_output_handler(self, engine: Engine, *args: Any, **kwargs: Any) -> Callable: pass @abstractmethod - def _create_opt_params_handler(self, *args: Any, **kwargs: Any): + def _create_opt_params_handler(self, *args: Any, **kwargs: Any) -> Callable: pass - def __enter__(self): + def __enter__(self) -> "BaseLogger": return self - def __exit__(self, type, value, traceback): + def __exit__(self, type: Any, value: Any, traceback: Any) -> None: self.close() - def close(self): + def close(self) -> None: pass diff --git a/ignite/contrib/handlers/lr_finder.py b/ignite/contrib/handlers/lr_finder.py index 149a32e60cdf..2ad702171b01 100644 --- a/ignite/contrib/handlers/lr_finder.py +++ b/ignite/contrib/handlers/lr_finder.py @@ -4,7 +4,7 @@ import tempfile import warnings from pathlib import Path -from typing import Callable, Mapping, Optional +from typing import Any, Callable, Dict, List, Mapping, Optional, Union import torch from torch.optim import Optimizer @@ -71,11 +71,11 @@ class FastaiLRFinder: fastai/lr_find: https://github.com/fastai/fastai """ - def __init__(self): + def __init__(self) -> None: self._diverge_flag = False - self._history = None + self._history = {} # type: Dict[str, List[Any]] self._best_loss = None - self._lr_schedule = None + self._lr_schedule = None # type: Optional[Union[LRScheduler, PiecewiseLinear]] self.logger = logging.getLogger(__name__) def _run( @@ -88,7 +88,7 @@ def _run( step_mode: str, smooth_f: float, diverge_th: float, - ): + ) -> None: self._history = {"lr": [], "loss": []} self._best_loss = None @@ -98,7 +98,7 @@ def _run( if num_iter is None: num_iter = trainer.state.epoch_length * trainer.state.max_epochs else: - max_iter = trainer.state.epoch_length * trainer.state.max_epochs + max_iter = trainer.state.epoch_length * trainer.state.max_epochs # type: ignore[operator] if num_iter > max_iter: warnings.warn( "Desired num_iter {} is unreachable with the current run setup of {} iteration " @@ -127,16 +127,16 @@ def _run( if not trainer.has_event_handler(self._lr_schedule): trainer.add_event_handler(Events.ITERATION_COMPLETED, self._lr_schedule, num_iter) - def _reset(self, trainer: Engine): + def _reset(self, trainer: Engine) -> None: self.logger.debug("Completed LR finder run") - trainer.remove_event_handler(self._lr_schedule, Events.ITERATION_COMPLETED) + trainer.remove_event_handler(self._lr_schedule, Events.ITERATION_COMPLETED) # type: ignore[arg-type] trainer.remove_event_handler(self._log_lr_and_loss, Events.ITERATION_COMPLETED) trainer.remove_event_handler(self._reached_num_iterations, Events.ITERATION_COMPLETED) - def _log_lr_and_loss(self, trainer: Engine, output_transform: Callable, smooth_f: float, diverge_th: float): + def _log_lr_and_loss(self, trainer: Engine, output_transform: Callable, smooth_f: float, diverge_th: float) -> None: output = trainer.state.output loss = output_transform(output) - lr = self._lr_schedule.get_param() + lr = self._lr_schedule.get_param() # type: ignore[union-attr] self._history["lr"].append(lr) if trainer.state.iteration == 1: self._best_loss = loss @@ -148,16 +148,16 @@ def _log_lr_and_loss(self, trainer: Engine, output_transform: Callable, smooth_f self._history["loss"].append(loss) # Check if the loss has diverged; if it has, stop the trainer - if self._history["loss"][-1] > diverge_th * self._best_loss: + if self._history["loss"][-1] > diverge_th * self._best_loss: # type: ignore[operator] self._diverge_flag = True self.logger.info("Stopping early, the loss has diverged") trainer.terminate() - def _reached_num_iterations(self, trainer: Engine, num_iter: int): + def _reached_num_iterations(self, trainer: Engine, num_iter: int) -> None: if trainer.state.iteration > num_iter: trainer.terminate() - def _warning(self, _): + def _warning(self, _: Any) -> None: if not self._diverge_flag: warnings.warn( "Run completed without loss diverging, increase end_lr, decrease diverge_th or look" @@ -165,7 +165,7 @@ def _warning(self, _): UserWarning, ) - def _detach(self, trainer: Engine): + def _detach(self, trainer: Engine) -> None: """ Detaches lr_finder from trainer. @@ -180,13 +180,13 @@ def _detach(self, trainer: Engine): if trainer.has_event_handler(self._reset, Events.COMPLETED): trainer.remove_event_handler(self._reset, Events.COMPLETED) - def get_results(self): + def get_results(self) -> Dict[str, List[Any]]: """ Returns: dictionary with loss and lr logs fromm the previous run """ return self._history - def plot(self, skip_start: int = 10, skip_end: int = 5, log_lr: bool = True): + def plot(self, skip_start: int = 10, skip_end: int = 5, log_lr: bool = True) -> None: """Plots the learning rate range test. This method requires `matplotlib` package to be installed: @@ -211,7 +211,7 @@ def plot(self, skip_start: int = 10, skip_end: int = 5, log_lr: bool = True): "Please install it with command: \n pip install matplotlib" ) - if self._history is None: + if not self._history: raise RuntimeError("learning rate finder didn't run yet so results can't be plotted") if skip_start < 0: @@ -239,11 +239,11 @@ def plot(self, skip_start: int = 10, skip_end: int = 5, log_lr: bool = True): plt.ylabel("Loss") plt.show() - def lr_suggestion(self): + def lr_suggestion(self) -> Any: """ Returns: learning rate at the minimum numerical gradient """ - if self._history is None: + if not self._history: raise RuntimeError("learning rate finder didn't run yet so lr_suggestion can't be returned") loss = self._history["loss"] grads = torch.tensor([loss[i] - loss[i - 1] for i in range(1, len(loss))]) @@ -261,7 +261,7 @@ def attach( step_mode: str = "exp", smooth_f: float = 0.05, diverge_th: float = 5.0, - ): + ) -> Any: """Attaches lr_finder to a given trainer. It also resets model and optimizer at the end of the run. Usage: @@ -372,12 +372,12 @@ class _ExponentialLR(_LRScheduler): """ - def __init__(self, optimizer: Optimizer, end_lr: float, num_iter: int, last_epoch: int = -1): + def __init__(self, optimizer: Optimizer, end_lr: float, num_iter: int, last_epoch: int = -1) -> None: self.end_lr = end_lr self.num_iter = num_iter super(_ExponentialLR, self).__init__(optimizer, last_epoch) - def get_lr(self): - curr_iter = self.last_epoch + 1 + def get_lr(self) -> List[float]: # type: ignore + curr_iter = self.last_epoch + 1 # type: ignore[attr-defined] r = curr_iter / self.num_iter - return [base_lr * (self.end_lr / base_lr) ** r for base_lr in self.base_lrs] + return [base_lr * (self.end_lr / base_lr) ** r for base_lr in self.base_lrs] # type: ignore[attr-defined] diff --git a/ignite/contrib/handlers/mlflow_logger.py b/ignite/contrib/handlers/mlflow_logger.py index bc8478f0d17a..d55b9698363f 100644 --- a/ignite/contrib/handlers/mlflow_logger.py +++ b/ignite/contrib/handlers/mlflow_logger.py @@ -1,12 +1,12 @@ import numbers import warnings -from typing import Any, Callable, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union import torch from torch.optim import Optimizer from ignite.contrib.handlers.base_logger import BaseLogger, BaseOptimizerParamsHandler, BaseOutputHandler -from ignite.engine import Engine, EventEnum +from ignite.engine import Engine, Events from ignite.handlers import global_step_from_engine __all__ = ["MLflowLogger", "OutputHandler", "OptimizerParamsHandler", "global_step_from_engine"] @@ -86,7 +86,7 @@ class MLflowLogger(BaseLogger): ) """ - def __init__(self, tracking_uri: Optional[str] = None): + def __init__(self, tracking_uri: Optional[str] = None) -> None: try: import mlflow except ImportError: @@ -102,21 +102,21 @@ def __init__(self, tracking_uri: Optional[str] = None): if self.active_run is None: self.active_run = mlflow.start_run() - def __getattr__(self, attr: Any): + def __getattr__(self, attr: Any) -> Any: import mlflow return getattr(mlflow, attr) - def close(self): + def close(self) -> None: import mlflow mlflow.end_run() - def _create_output_handler(self, *args: Any, **kwargs: Any): + def _create_output_handler(self, *args: Any, **kwargs: Any) -> "OutputHandler": return OutputHandler(*args, **kwargs) - def _create_opt_params_handler(self, *args: Any, **kwargs: Any): + def _create_opt_params_handler(self, *args: Any, **kwargs: Any) -> "OptimizerParamsHandler": return OptimizerParamsHandler(*args, **kwargs) @@ -212,17 +212,17 @@ def __init__( metric_names: Optional[Union[str, List[str]]] = None, output_transform: Optional[Callable] = None, global_step_transform: Optional[Callable] = None, - ): + ) -> None: super(OutputHandler, self).__init__(tag, metric_names, output_transform, global_step_transform) - def __call__(self, engine: Engine, logger: MLflowLogger, event_name: Union[str, EventEnum]): + def __call__(self, engine: Engine, logger: MLflowLogger, event_name: Union[str, Events]) -> None: if not isinstance(logger, MLflowLogger): raise TypeError("Handler 'OutputHandler' works only with MLflowLogger") metrics = self._setup_output_metrics(engine) - global_step = self.global_step_transform(engine, event_name) + global_step = self.global_step_transform(engine, event_name) # type: ignore[misc] if not isinstance(global_step, int): raise TypeError( @@ -230,10 +230,10 @@ def __call__(self, engine: Engine, logger: MLflowLogger, event_name: Union[str, " Please check the output of global_step_transform.".format(type(global_step)) ) - rendered_metrics = {} + rendered_metrics = {} # type: Dict[str, float] for key, value in metrics.items(): if isinstance(value, numbers.Number): - rendered_metrics["{} {}".format(self.tag, key)] = value + rendered_metrics["{} {}".format(self.tag, key)] = value # type: ignore[assignment] elif isinstance(value, torch.Tensor) and value.ndimension() == 0: rendered_metrics["{} {}".format(self.tag, key)] = value.item() elif isinstance(value, torch.Tensor) and value.ndimension() == 1: @@ -290,10 +290,10 @@ class OptimizerParamsHandler(BaseOptimizerParamsHandler): tag (str, optional): common title for all produced plots. For example, 'generator' """ - def __init__(self, optimizer: Optimizer, param_name: str = "lr", tag: Optional[str] = None): + def __init__(self, optimizer: Optimizer, param_name: str = "lr", tag: Optional[str] = None) -> None: super(OptimizerParamsHandler, self).__init__(optimizer, param_name, tag) - def __call__(self, engine: Engine, logger: MLflowLogger, event_name: Union[str, EventEnum]): + def __call__(self, engine: Engine, logger: MLflowLogger, event_name: Union[str, Events]) -> None: if not isinstance(logger, MLflowLogger): raise TypeError("Handler OptimizerParamsHandler works only with MLflowLogger") diff --git a/ignite/contrib/handlers/neptune_logger.py b/ignite/contrib/handlers/neptune_logger.py index 18cf523ee62d..98216a26206c 100644 --- a/ignite/contrib/handlers/neptune_logger.py +++ b/ignite/contrib/handlers/neptune_logger.py @@ -15,7 +15,7 @@ BaseOutputHandler, BaseWeightsScalarHandler, ) -from ignite.engine import Engine, EventEnum +from ignite.engine import Engine, Events from ignite.handlers import global_step_from_engine from ignite.handlers.checkpoint import BaseSaveHandler @@ -176,13 +176,13 @@ def score_function(engine): """ - def __getattr__(self, attr: Any): + def __getattr__(self, attr: Any) -> Any: import neptune return getattr(neptune, attr) - def __init__(self, *args: Any, **kwargs: Any): + def __init__(self, *args: Any, **kwargs: Any) -> None: try: import neptune except ImportError: @@ -205,13 +205,13 @@ def __init__(self, *args: Any, **kwargs: Any): self.experiment = neptune.create_experiment(**self._experiment_kwargs) - def close(self): + def close(self) -> None: self.stop() - def _create_output_handler(self, *args: Any, **kwargs: Any): + def _create_output_handler(self, *args: Any, **kwargs: Any) -> "OutputHandler": return OutputHandler(*args, **kwargs) - def _create_opt_params_handler(self, *args: Any, **kwargs: Any): + def _create_opt_params_handler(self, *args: Any, **kwargs: Any) -> "OptimizerParamsHandler": return OptimizerParamsHandler(*args, **kwargs) @@ -323,17 +323,17 @@ def __init__( metric_names: Optional[Union[str, List[str]]] = None, output_transform: Optional[Callable] = None, global_step_transform: Optional[Callable] = None, - ): + ) -> None: super(OutputHandler, self).__init__(tag, metric_names, output_transform, global_step_transform) - def __call__(self, engine: Engine, logger: NeptuneLogger, event_name: Union[str, EventEnum]): + def __call__(self, engine: Engine, logger: NeptuneLogger, event_name: Union[str, Events]) -> None: if not isinstance(logger, NeptuneLogger): raise TypeError("Handler OutputHandler works only with NeptuneLogger") metrics = self._setup_output_metrics(engine) - global_step = self.global_step_transform(engine, event_name) + global_step = self.global_step_transform(engine, event_name) # type: ignore[misc] if not isinstance(global_step, int): raise TypeError( @@ -391,10 +391,10 @@ class OptimizerParamsHandler(BaseOptimizerParamsHandler): tag (str, optional): common title for all produced plots. For example, "generator" """ - def __init__(self, optimizer: Optimizer, param_name: str = "lr", tag: Optional[str] = None): + def __init__(self, optimizer: Optimizer, param_name: str = "lr", tag: Optional[str] = None) -> None: super(OptimizerParamsHandler, self).__init__(optimizer, param_name, tag) - def __call__(self, engine: Engine, logger: NeptuneLogger, event_name: Union[str, EventEnum]): + def __call__(self, engine: Engine, logger: NeptuneLogger, event_name: Union[str, Events]) -> None: if not isinstance(logger, NeptuneLogger): raise TypeError("Handler OptimizerParamsHandler works only with NeptuneLogger") @@ -445,10 +445,10 @@ class WeightsScalarHandler(BaseWeightsScalarHandler): """ - def __init__(self, model: nn.Module, reduction: Callable = torch.norm, tag: Optional[str] = None): + def __init__(self, model: nn.Module, reduction: Callable = torch.norm, tag: Optional[str] = None) -> None: super(WeightsScalarHandler, self).__init__(model, reduction, tag=tag) - def __call__(self, engine: Engine, logger: NeptuneLogger, event_name: Union[str, EventEnum]): + def __call__(self, engine: Engine, logger: NeptuneLogger, event_name: Union[str, Events]) -> None: if not isinstance(logger, NeptuneLogger): raise TypeError("Handler WeightsScalarHandler works only with NeptuneLogger") @@ -503,10 +503,10 @@ class GradsScalarHandler(BaseWeightsScalarHandler): """ - def __init__(self, model: nn.Module, reduction: Callable = torch.norm, tag: Optional[str] = None): + def __init__(self, model: nn.Module, reduction: Callable = torch.norm, tag: Optional[str] = None) -> None: super(GradsScalarHandler, self).__init__(model, reduction, tag=tag) - def __call__(self, engine: Engine, logger: NeptuneLogger, event_name: Any): + def __call__(self, engine: Engine, logger: NeptuneLogger, event_name: Union[str, Events]) -> None: if not isinstance(logger, NeptuneLogger): raise TypeError("Handler GradsScalarHandler works only with NeptuneLogger") @@ -590,7 +590,7 @@ def __call__(self, checkpoint: Mapping, filename: str, metadata: Optional[Mappin with tempfile.NamedTemporaryFile() as tmp: # we can not use tmp.name to open tmp.file twice on Win32 # https://docs.python.org/3/library/tempfile.html#tempfile.NamedTemporaryFile - torch.save(checkpoint, tmp.file) + torch.save(checkpoint, tmp.file) # type: ignore[attr-defined] self._logger.log_artifact(tmp.name, filename) @idist.one_rank_only(with_barrier=True) diff --git a/ignite/contrib/handlers/param_scheduler.py b/ignite/contrib/handlers/param_scheduler.py index 024d477c82ae..f08a5bfdf748 100644 --- a/ignite/contrib/handlers/param_scheduler.py +++ b/ignite/contrib/handlers/param_scheduler.py @@ -6,7 +6,7 @@ from collections import OrderedDict from copy import copy from pathlib import Path -from typing import Any, List, Mapping, Optional, Sequence, Tuple, Union +from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union, cast import torch from torch.optim.lr_scheduler import _LRScheduler @@ -40,7 +40,7 @@ def __init__( param_name: str, save_history: bool = False, param_group_index: Optional[int] = None, - ): + ) -> None: if not ( isinstance(optimizer, Optimizer) @@ -54,11 +54,11 @@ def __init__( self.optimizer = optimizer self.param_group_index = param_group_index self.param_name = param_name - self.save_history = save_history self.event_index = 0 + self._save_history = save_history self._state_attrs = ["event_index", "param_name", "save_history", "param_group_index"] - def __call__(self, engine: Engine, name: Optional[str] = None): + def __call__(self, engine: Optional[Engine], name: Optional[str] = None) -> None: value = self.get_param() @@ -79,21 +79,29 @@ def __call__(self, engine: Engine, name: Optional[str] = None): if name is None: name = self.param_name - if self.save_history: - if not hasattr(engine.state, "param_history") or engine.state.param_history is None: + if self.save_history and engine: + if not hasattr(engine.state, "param_history") or engine.state.param_history is None: # type: ignore setattr(engine.state, "param_history", {}) - engine.state.param_history.setdefault(name, []) + engine.state.param_history.setdefault(name, []) # type: ignore[attr-defined] values = [pg[self.param_name] for pg in self.optimizer_param_groups] - engine.state.param_history[name].append(values) + engine.state.param_history[name].append(values) # type: ignore[attr-defined] self.event_index += 1 @property - def optimizer_param_groups(self): + def optimizer_param_groups(self) -> List[Dict[str, Any]]: if self.param_group_index is None: return self.optimizer.param_groups return [self.optimizer.param_groups[self.param_group_index]] - def state_dict(self): + @property + def save_history(self) -> bool: + return self._save_history + + @save_history.setter + def save_history(self, value: bool) -> None: + self._save_history = value + + def state_dict(self) -> Dict[str, Any]: """Returns a dictionary containing a whole state of ParamScheduler. Returns: @@ -109,7 +117,7 @@ def state_dict(self): destination[name] = copy(val) return destination - def load_state_dict(self, state_dict: Mapping): + def load_state_dict(self, state_dict: Mapping) -> None: """Copies parameters from :attr:`state_dict` into this ParamScheduler. Args: @@ -142,7 +150,7 @@ def get_param(self) -> Union[List[float], float]: pass @classmethod - def simulate_values(cls, num_events: int, **scheduler_kwargs: Any): + def simulate_values(cls, num_events: int, **scheduler_kwargs: Any) -> List[List[int]]: """Method to simulate scheduled values during `num_events` events. Args: @@ -178,7 +186,7 @@ def simulate_values(cls, num_events: int, **scheduler_kwargs: Any): return values @classmethod - def plot_values(cls, num_events: int, **scheduler_kwargs: Mapping): + def plot_values(cls, num_events: int, **scheduler_kwargs: Mapping) -> Any: """Method to plot simulated scheduled values during `num_events` events. This class requires `matplotlib package `_ to be installed: @@ -285,10 +293,10 @@ def __init__( "end_value_mult", ] - def __call__(self, engine: Engine, name: Optional[str] = None): + def __call__(self, engine: Optional[Engine], name: Optional[str] = None) -> None: if self.event_index != 0 and self.event_index % self.cycle_size == 0: self.event_index = 0 - self.cycle_size *= self.cycle_mult + self.cycle_size = int(self.cycle_size * self.cycle_mult) self.cycle += 1 self.start_value *= self.start_value_mult self.end_value *= self.end_value_mult @@ -335,7 +343,7 @@ class LinearCyclicalScheduler(CyclicalScheduler): # """ - def get_param(self): + def get_param(self) -> float: cycle_progress = self.event_index / self.cycle_size return self.end_value + (self.start_value - self.end_value) * abs(cycle_progress - 0.5) * 2 @@ -401,7 +409,7 @@ class CosineAnnealingScheduler(CyclicalScheduler): Applications of Computer Vision (WACV), 2017 IEEE Winter Conference on. IEEE, 2017 """ - def get_param(self): + def get_param(self) -> float: """Method to get current optimizer's parameter value """ cycle_progress = self.event_index / self.cycle_size @@ -441,7 +449,7 @@ class ConcatScheduler(ParamScheduler): """ - def __init__(self, schedulers: List[ParamScheduler], durations: List[int], save_history: bool = False): + def __init__(self, schedulers: List[ParamScheduler], durations: List[int], save_history: bool = False) -> None: if not isinstance(schedulers, Sequence): raise TypeError("Argument schedulers should be a sequence, but given {}".format(schedulers)) @@ -474,17 +482,17 @@ def __init__(self, schedulers: List[ParamScheduler], durations: List[int], save_ self.schedulers = schedulers self.durations = durations - param_optimizers = [s.optimizer for s in self.schedulers] - param_optimizers = [s if isinstance(s, list) else [s] for s in param_optimizers] - param_optimizers = list(itertools.chain(*param_optimizers)) + tmp_optimizers = [s.optimizer for s in self.schedulers] + tmps_list_optimizers = [s if isinstance(s, list) else [s] for s in tmp_optimizers] + param_optimizers = list(itertools.chain(*tmps_list_optimizers)) optimizer = list(set(param_optimizers)) if len(optimizer) != 1: raise ValueError("schedulers should be related to same optimizer") - param_names = [s.param_name for s in self.schedulers] - param_names = [s if isinstance(s, list) else [s] for s in param_names] - param_names = list(itertools.chain(*param_names)) + tmp_param_names = [s.param_name for s in self.schedulers] + tmp_list_param_names = [s if isinstance(s, list) else [s] for s in tmp_param_names] + param_names = list(itertools.chain(*tmp_list_param_names)) param_name = list(set(param_names)) if len(param_name) != 1: @@ -499,12 +507,12 @@ def __init__(self, schedulers: List[ParamScheduler], durations: List[int], save_ ) self._scheduler_index = 0 - self._current_scheduler = None - self._current_duration = None + # self._current_scheduler = None + # self._current_duration = None self._setup_scheduler() self._state_attrs += ["_current_duration", "durations", "_scheduler_index"] - def state_dict(self): + def state_dict(self) -> Dict[str, Any]: """Returns a dictionary containing a whole state of ConcatScheduler. Returns: @@ -518,7 +526,7 @@ def state_dict(self): state_dict["schedulers"].append(s.state_dict()) return state_dict - def load_state_dict(self, state_dict: Mapping): + def load_state_dict(self, state_dict: Mapping) -> None: """Copies parameters from :attr:`state_dict` into this ConcatScheduler. Args: @@ -545,13 +553,13 @@ def load_state_dict(self, state_dict: Mapping): super(ConcatScheduler, self).load_state_dict(state_dict) self._setup_scheduler() - def _setup_scheduler(self): + def _setup_scheduler(self) -> None: self._current_scheduler = self.schedulers[self._scheduler_index] self._current_duration = ( self.durations[self._scheduler_index] if self._scheduler_index < len(self.durations) else -1 ) - def __call__(self, engine: Engine, name: Optional[str] = None): + def __call__(self, engine: Optional[Engine], name: Optional[str] = None) -> None: if self._current_duration == 0: self._scheduler_index += 1 self._setup_scheduler() @@ -559,32 +567,32 @@ def __call__(self, engine: Engine, name: Optional[str] = None): self._current_duration -= 1 @property - def optimizer_param_groups(self): + def optimizer_param_groups(self) -> List[Dict[str, Any]]: # We need to setup optimizer_param_groups as property # to synchonize with the latest _current_scheduler and its internal optimizer_param_groups return self._current_scheduler.optimizer_param_groups @property - def save_history(self): + def save_history(self) -> bool: return self._current_scheduler.save_history @save_history.setter - def save_history(self, value: bool): + def save_history(self, value: bool) -> None: for s in self.schedulers: s.save_history = value - def get_param(self): + def get_param(self) -> Union[List[float], float]: return self._current_scheduler.get_param() @classmethod - def simulate_values( + def simulate_values( # type: ignore[override] cls, num_events: int, schedulers: List[ParamScheduler], durations: List[int], param_names: Optional[Union[List[str], Tuple[str]]] = None, **kwargs: Any - ): + ) -> List[List[int]]: """Method to simulate scheduled values during num_events events. Args: @@ -606,15 +614,15 @@ def simulate_values( "Argument param_names should be list or tuple of strings, but given {}".format(param_names) ) - param_optimizers = [s.optimizer for s in schedulers] - param_optimizers = [s if isinstance(s, list) else [s] for s in param_optimizers] - param_optimizers = list(itertools.chain(*param_optimizers)) + tmp_param_optimizers = [s.optimizer for s in schedulers] + tmp_list_param_optimizers = [s if isinstance(s, list) else [s] for s in tmp_param_optimizers] + param_optimizers = list(itertools.chain(*tmp_list_param_optimizers)) - optimizer = list(set(param_optimizers)) - if len(optimizer) != 1: + tmp_optimizer = list(set(param_optimizers)) + if len(tmp_optimizer) != 1: raise ValueError("schedulers should be related to same optimizer") - optimizer = optimizer[0] + optimizer = tmp_optimizer[0] # This scheduler uses `ParamScheduler` which # should be replicated in order to simulate LR values and @@ -632,7 +640,9 @@ def simulate_values( s.save_history = False output = [] - scheduler = cls(schedulers=schedulers, save_history=False, durations=durations, **kwargs) + scheduler = cls( # type: ignore[call-arg] + schedulers=schedulers, save_history=False, durations=durations, **kwargs + ) if param_names is None: param_names = [scheduler.param_name] for i in range(num_events): @@ -675,7 +685,7 @@ class LRScheduler(ParamScheduler): trainer.add_event_handler(Events.ITERATION_COMPLETED, scheduler) """ - def __init__(self, lr_scheduler: _LRScheduler, save_history=False): + def __init__(self, lr_scheduler: _LRScheduler, save_history: bool = False) -> None: if not isinstance(lr_scheduler, _LRScheduler): raise TypeError( @@ -685,28 +695,32 @@ def __init__(self, lr_scheduler: _LRScheduler, save_history=False): self.lr_scheduler = lr_scheduler super(LRScheduler, self).__init__( - optimizer=self.lr_scheduler.optimizer, param_name="lr", save_history=save_history + optimizer=self.lr_scheduler.optimizer, # type: ignore[attr-defined] + param_name="lr", + save_history=save_history, ) self._state_attrs += ["lr_scheduler"] - def __call__(self, engine: Engine, name: Optional[str] = None): - self.lr_scheduler.last_epoch += 1 + def __call__(self, engine: Optional[Engine], name: Optional[str] = None) -> None: + self.lr_scheduler.last_epoch += 1 # type: ignore[attr-defined] super(LRScheduler, self).__call__(engine, name) def get_param(self) -> Union[float, List[float]]: """Method to get current optimizer's parameter value """ # Emulate context manager for pytorch>=1.4 - self.lr_scheduler._get_lr_called_within_step = True - lr_list = self.lr_scheduler.get_lr() - self.lr_scheduler._get_lr_called_within_step = False + self.lr_scheduler._get_lr_called_within_step = True # type: ignore[attr-defined] + lr_list = cast(List[float], self.lr_scheduler.get_lr()) + self.lr_scheduler._get_lr_called_within_step = False # type: ignore[attr-defined] if len(lr_list) == 1: return lr_list[0] else: return lr_list @classmethod - def simulate_values(cls, num_events: int, lr_scheduler: _LRScheduler, **kwargs: Any): + def simulate_values( # type: ignore[override] + cls, num_events: int, lr_scheduler: _LRScheduler, **kwargs: Any + ) -> List[List[int]]: """Method to simulate scheduled values during num_events events. Args: @@ -729,11 +743,14 @@ def simulate_values(cls, num_events: int, lr_scheduler: _LRScheduler, **kwargs: # not perturb original scheduler. with tempfile.TemporaryDirectory() as tmpdirname: cache_filepath = Path(tmpdirname) / "ignite_lr_scheduler_cache.pt" - obj = {"lr_scheduler": lr_scheduler.state_dict(), "optimizer": lr_scheduler.optimizer.state_dict()} + obj = { + "lr_scheduler": lr_scheduler.state_dict(), + "optimizer": lr_scheduler.optimizer.state_dict(), # type: ignore[attr-defined] + } torch.save(obj, cache_filepath.as_posix()) values = [] - scheduler = cls(save_history=False, lr_scheduler=lr_scheduler, **kwargs) + scheduler = cls(save_history=False, lr_scheduler=lr_scheduler, **kwargs) # type: ignore[call-arg] for i in range(num_events): params = [p[scheduler.param_name] for p in scheduler.optimizer_param_groups] values.append([i] + params) @@ -741,7 +758,7 @@ def simulate_values(cls, num_events: int, lr_scheduler: _LRScheduler, **kwargs: obj = torch.load(cache_filepath.as_posix()) lr_scheduler.load_state_dict(obj["lr_scheduler"]) - lr_scheduler.optimizer.load_state_dict(obj["optimizer"]) + lr_scheduler.optimizer.load_state_dict(obj["optimizer"]) # type: ignore[attr-defined] return values @@ -753,7 +770,7 @@ def create_lr_scheduler_with_warmup( warmup_end_value: Optional[float] = None, save_history: bool = False, output_simulated_values: Optional[List] = None, -): +) -> "ConcatScheduler": """ Helper method to create a learning rate scheduler with a linear warm-up. @@ -809,9 +826,9 @@ def create_lr_scheduler_with_warmup( if not (warmup_duration > 1): raise ValueError("Argument warmup_duration should be at least 2 events, but given {}".format(warmup_duration)) - warmup_schedulers = [] + warmup_schedulers = [] # type: List[ParamScheduler] - for param_group_index, param_group in enumerate(lr_scheduler.optimizer.param_groups): + for param_group_index, param_group in enumerate(lr_scheduler.optimizer.param_groups): # type: ignore[union-attr] if warmup_end_value is None: param_group_warmup_end_value = param_group["lr"] @@ -836,21 +853,28 @@ def create_lr_scheduler_with_warmup( else: milestones_values.pop(-1) - warmup_scheduler = PiecewiseLinear( - lr_scheduler.optimizer, - param_name="lr", - milestones_values=milestones_values, - param_group_index=param_group_index, - save_history=save_history, + warmup_schedulers.append( + PiecewiseLinear( + lr_scheduler.optimizer, + param_name="lr", + milestones_values=milestones_values, + param_group_index=param_group_index, + save_history=save_history, + ) ) - warmup_schedulers.append(warmup_scheduler) - warmup_scheduler = ParamGroupScheduler(warmup_schedulers, save_history=save_history) - schedulers = [warmup_scheduler, lr_scheduler] + schedulers = [ + warmup_scheduler, + lr_scheduler, + ] # type: List[Union[ParamScheduler, ParamGroupScheduler, _LRScheduler]] durations = [milestones_values[-1][0] + 1] - combined_scheduler = ConcatScheduler(schedulers, durations=durations, save_history=save_history) + combined_scheduler = ConcatScheduler( + schedulers, # type: ignore[arg-type] + durations=durations, + save_history=save_history, + ) if output_simulated_values is not None: if not isinstance(output_simulated_values, list): @@ -859,7 +883,11 @@ def create_lr_scheduler_with_warmup( "but given {}.".format(type(output_simulated_values)) ) num_events = len(output_simulated_values) - result = ConcatScheduler.simulate_values(num_events=num_events, schedulers=schedulers, durations=durations) + result = ConcatScheduler.simulate_values( + num_events=num_events, + schedulers=schedulers, # type: ignore[arg-type] + durations=durations, + ) for i in range(num_events): output_simulated_values[i] = result[i] return combined_scheduler @@ -916,10 +944,10 @@ def __init__( "but given {}".format(milestones_values) ) - values = [] - milestones = [] + values = [] # type: List[float] + milestones = [] # type: List[int] for pair in milestones_values: - if not isinstance(pair, Sequence) or len(pair) != 2: + if not isinstance(pair, tuple) or len(pair) != 2: raise ValueError("Argument milestones_values should be a list of pairs (milestone, param_value)") if not isinstance(pair[0], numbers.Integral): raise TypeError("Value of a milestone should be integer, but given {}".format(type(pair[0]))) @@ -936,7 +964,7 @@ def __init__( self._index = 0 self._state_attrs += ["values", "milestones", "_index"] - def _get_start_end(self): + def _get_start_end(self) -> Tuple[int, int, float, float]: if self.milestones[0] > self.event_index: return self.event_index - 1, self.event_index, self.values[0], self.values[0] elif self.milestones[-1] <= self.event_index: @@ -952,7 +980,7 @@ def _get_start_end(self): self._index += 1 return self._get_start_end() - def get_param(self): + def get_param(self) -> float: start_index, end_index, start_value, end_value = self._get_start_end() return start_value + (end_value - start_value) * (self.event_index - start_index) / (end_index - start_index) @@ -1018,37 +1046,37 @@ def __init__(self, schedulers: List[ParamScheduler], names: Optional[List[str]] self.optimizer = [s.optimizer for s in self.schedulers] self.param_name = [s.param_name for s in self.schedulers] - def __call__(self, engine: Engine, name: Optional[str] = None): + def __call__(self, engine: Optional[Engine], name: Optional[str] = None) -> None: for scheduler, name in zip(self.schedulers, self.names): scheduler(engine, name) @property - def optimizer_param_groups(self): + def optimizer_param_groups(self) -> List[Dict[str, Any]]: return [pg for scheduler in self.schedulers for pg in scheduler.optimizer_param_groups] @property - def save_history(self): + def save_history(self) -> bool: return self.schedulers[0].save_history @save_history.setter - def save_history(self, value: bool): + def save_history(self, value: bool) -> None: for s in self.schedulers: s.save_history = value - def state_dict(self): + def state_dict(self) -> Dict[str, List[Any]]: """Returns a dictionary containing a whole state of ParamGroupScheduler. Returns: dict: a dictionary containing a whole state of ParamGroupScheduler """ - state_dict = OrderedDict() + state_dict = OrderedDict() # type: Dict[str, List[Any]] state_dict["schedulers"] = [] for n, s in zip(self.names, self.schedulers): state_dict["schedulers"].append((n, s.state_dict())) return state_dict - def load_state_dict(self, state_dict: Mapping): + def load_state_dict(self, state_dict: Mapping) -> None: """Copies parameters from :attr:`state_dict` into this ParamScheduler. Args: @@ -1079,7 +1107,7 @@ def load_state_dict(self, state_dict: Mapping): s.load_state_dict(sd) @classmethod - def simulate_values(cls, num_events: int, schedulers: _LRScheduler, **kwargs: Any): + def simulate_values(cls, num_events: int, schedulers: List[_LRScheduler], **kwargs: Any) -> List[List[int]]: """Method to simulate scheduled values during num_events events. Args: @@ -1098,26 +1126,28 @@ def simulate_values(cls, num_events: int, schedulers: _LRScheduler, **kwargs: An cache_filepath = Path(tmpdirname) / "ignite_lr_scheduler_cache.pt" objs = {"lr_scheduler_{}".format(i): s.state_dict() for i, s in enumerate(schedulers)} # all schedulers should be related to the same optimizer - objs["optimizer"] = schedulers[0].optimizer.state_dict() + objs["optimizer"] = schedulers[0].optimizer.state_dict() # type: ignore[attr-defined] torch.save(objs, cache_filepath.as_posix()) values = [] - scheduler = cls(schedulers=schedulers, **kwargs) + scheduler = cls(schedulers=schedulers, **kwargs) # type: ignore[arg-type] for i in range(num_events): - params = [scheduler.get_param() for scheduler in schedulers] + params = [scheduler.get_param() for scheduler in schedulers] # type: ignore[attr-defined] values.append([i] + params) scheduler(engine=None) objs = torch.load(cache_filepath.as_posix()) for i, s in enumerate(schedulers): s.load_state_dict(objs["lr_scheduler_{}".format(i)]) - s.optimizer.load_state_dict(objs["optimizer"]) + s.optimizer.load_state_dict(objs["optimizer"]) # type: ignore[attr-defined] return values -def _get_fake_optimizer(optimizer_cls: Optional[Optimizer] = None, **kwargs: Any): +def _get_fake_optimizer( + optimizer_cls: Optional[Union[Type[Optimizer], Type[torch.optim.SGD]]] = None, **kwargs: Any +) -> Union[Optimizer, torch.optim.SGD]: t = torch.zeros([1], requires_grad=True) if optimizer_cls is None: optimizer_cls = torch.optim.SGD diff --git a/ignite/contrib/handlers/polyaxon_logger.py b/ignite/contrib/handlers/polyaxon_logger.py index 9d75dff12481..70a2416aa569 100644 --- a/ignite/contrib/handlers/polyaxon_logger.py +++ b/ignite/contrib/handlers/polyaxon_logger.py @@ -1,12 +1,12 @@ import numbers import warnings -from typing import Any, Callable, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union import torch from torch.optim import Optimizer from ignite.contrib.handlers.base_logger import BaseLogger, BaseOptimizerParamsHandler, BaseOutputHandler -from ignite.engine import Engine, EventEnum +from ignite.engine import Engine, Events from ignite.handlers import global_step_from_engine __all__ = ["PolyaxonLogger", "OutputHandler", "OptimizerParamsHandler", "global_step_from_engine"] @@ -91,7 +91,7 @@ class PolyaxonLogger(BaseLogger): """ - def __init__(self, *args: Any, **kwargs: Any): + def __init__(self, *args: Any, **kwargs: Any) -> None: try: from polyaxon_client.tracking import Experiment except ImportError: @@ -102,13 +102,13 @@ def __init__(self, *args: Any, **kwargs: Any): self.experiment = Experiment(*args, **kwargs) - def __getattr__(self, attr: Any): + def __getattr__(self, attr: Any) -> Any: return getattr(self.experiment, attr) - def _create_output_handler(self, *args: Any, **kwargs: Any): + def _create_output_handler(self, *args: Any, **kwargs: Any) -> "OutputHandler": return OutputHandler(*args, **kwargs) - def _create_opt_params_handler(self, *args: Any, **kwargs: Any): + def _create_opt_params_handler(self, *args: Any, **kwargs: Any) -> "OptimizerParamsHandler": return OptimizerParamsHandler(*args, **kwargs) @@ -204,17 +204,17 @@ def __init__( metric_names: Optional[List[str]] = None, output_transform: Optional[Callable] = None, global_step_transform: Optional[Callable] = None, - ): + ) -> None: super(OutputHandler, self).__init__(tag, metric_names, output_transform, global_step_transform) - def __call__(self, engine: Engine, logger: PolyaxonLogger, event_name: Union[str, EventEnum]): + def __call__(self, engine: Engine, logger: PolyaxonLogger, event_name: Union[str, Events]) -> None: if not isinstance(logger, PolyaxonLogger): raise RuntimeError("Handler 'OutputHandler' works only with PolyaxonLogger") metrics = self._setup_output_metrics(engine) - global_step = self.global_step_transform(engine, event_name) + global_step = self.global_step_transform(engine, event_name) # type: ignore[misc] if not isinstance(global_step, int): raise TypeError( @@ -222,7 +222,7 @@ def __call__(self, engine: Engine, logger: PolyaxonLogger, event_name: Union[str " Please check the output of global_step_transform.".format(type(global_step)) ) - rendered_metrics = {"step": global_step} + rendered_metrics = {"step": global_step} # type: Dict[str, Union[float, numbers.Number]] for key, value in metrics.items(): if isinstance(value, numbers.Number): rendered_metrics["{}/{}".format(self.tag, key)] = value @@ -269,10 +269,10 @@ class OptimizerParamsHandler(BaseOptimizerParamsHandler): tag (str, optional): common title for all produced plots. For example, "generator" """ - def __init__(self, optimizer: Optimizer, param_name: str = "lr", tag: Optional[str] = None): + def __init__(self, optimizer: Optimizer, param_name: str = "lr", tag: Optional[str] = None) -> None: super(OptimizerParamsHandler, self).__init__(optimizer, param_name, tag) - def __call__(self, engine: Engine, logger: PolyaxonLogger, event_name: Union[str, EventEnum]): + def __call__(self, engine: Engine, logger: PolyaxonLogger, event_name: Union[str, Events]) -> None: if not isinstance(logger, PolyaxonLogger): raise RuntimeError("Handler OptimizerParamsHandler works only with PolyaxonLogger") diff --git a/ignite/contrib/handlers/stores.py b/ignite/contrib/handlers/stores.py index 6bd2b5167bdb..51212bb445e1 100644 --- a/ignite/contrib/handlers/stores.py +++ b/ignite/contrib/handlers/stores.py @@ -1,4 +1,4 @@ -from typing import Callable, Optional +from typing import Callable, List, Tuple, Union from ignite.engine import Engine, Events @@ -33,17 +33,17 @@ def log_training_results(engine): """ - def __init__(self, output_transform: Callable = lambda x: x): - self.data = None + def __init__(self, output_transform: Callable = lambda x: x) -> None: + self.data = [] # type: List[Union[int, Tuple[int, int]]] self.output_transform = output_transform - def reset(self): + def reset(self) -> None: self.data = [] - def update(self, engine: Engine): + def update(self, engine: Engine) -> None: output = self.output_transform(engine.state.output) self.data.append(output) - def attach(self, engine: Engine): + def attach(self, engine: Engine) -> None: engine.add_event_handler(Events.EPOCH_STARTED, self.reset) engine.add_event_handler(Events.ITERATION_COMPLETED, self.update) diff --git a/ignite/contrib/handlers/tensorboard_logger.py b/ignite/contrib/handlers/tensorboard_logger.py index 798e11769f56..9acc2f1bda31 100644 --- a/ignite/contrib/handlers/tensorboard_logger.py +++ b/ignite/contrib/handlers/tensorboard_logger.py @@ -13,7 +13,7 @@ BaseWeightsHistHandler, BaseWeightsScalarHandler, ) -from ignite.engine import Engine, EventEnum +from ignite.engine import Engine, EventEnum, Events from ignite.handlers import global_step_from_engine __all__ = [ @@ -149,12 +149,12 @@ class TensorboardLogger(BaseLogger): """ - def __init__(self, *args: Any, **kwargs: Any): + def __init__(self, *args: Any, **kwargs: Any) -> None: try: from tensorboardX import SummaryWriter except ImportError: try: - from torch.utils.tensorboard import SummaryWriter + from torch.utils.tensorboard import SummaryWriter # type: ignore[no-redef] except ImportError: raise RuntimeError( "This contrib module requires either tensorboardX or torch >= 1.2.0. " @@ -164,13 +164,13 @@ def __init__(self, *args: Any, **kwargs: Any): self.writer = SummaryWriter(*args, **kwargs) - def close(self): + def close(self) -> None: self.writer.close() - def _create_output_handler(self, *args: Any, **kwargs: Any): + def _create_output_handler(self, *args: Any, **kwargs: Any) -> "OutputHandler": return OutputHandler(*args, **kwargs) - def _create_opt_params_handler(self, *args: Any, **kwargs: Any): + def _create_opt_params_handler(self, *args: Any, **kwargs: Any) -> "OptimizerParamsHandler": return OptimizerParamsHandler(*args, **kwargs) @@ -266,17 +266,17 @@ def __init__( metric_names: Optional[List[str]] = None, output_transform: Optional[Callable] = None, global_step_transform: Optional[Callable] = None, - ): + ) -> None: super(OutputHandler, self).__init__(tag, metric_names, output_transform, global_step_transform) - def __call__(self, engine: Engine, logger: TensorboardLogger, event_name: Union[str, EventEnum]): + def __call__(self, engine: Engine, logger: TensorboardLogger, event_name: Union[str, EventEnum]) -> None: if not isinstance(logger, TensorboardLogger): raise RuntimeError("Handler 'OutputHandler' works only with TensorboardLogger") metrics = self._setup_output_metrics(engine) - global_step = self.global_step_transform(engine, event_name) + global_step = self.global_step_transform(engine, event_name) # type: ignore[misc] if not isinstance(global_step, int): raise TypeError( "global_step must be int, got {}." @@ -325,10 +325,10 @@ class OptimizerParamsHandler(BaseOptimizerParamsHandler): tag (str, optional): common title for all produced plots. For example, "generator" """ - def __init__(self, optimizer: Optimizer, param_name: str = "lr", tag: Optional[str] = None): + def __init__(self, optimizer: Optimizer, param_name: str = "lr", tag: Optional[str] = None) -> None: super(OptimizerParamsHandler, self).__init__(optimizer, param_name, tag) - def __call__(self, engine: Engine, logger: TensorboardLogger, event_name: Union[str, EventEnum]): + def __call__(self, engine: Engine, logger: TensorboardLogger, event_name: Union[str, Events]) -> None: if not isinstance(logger, TensorboardLogger): raise RuntimeError("Handler OptimizerParamsHandler works only with TensorboardLogger") @@ -371,10 +371,10 @@ class WeightsScalarHandler(BaseWeightsScalarHandler): """ - def __init__(self, model: nn.Module, reduction: Callable = torch.norm, tag: Optional[str] = None): + def __init__(self, model: nn.Module, reduction: Callable = torch.norm, tag: Optional[str] = None) -> None: super(WeightsScalarHandler, self).__init__(model, reduction, tag=tag) - def __call__(self, engine: Engine, logger: TensorboardLogger, event_name: Union[str, EventEnum]): + def __call__(self, engine: Engine, logger: TensorboardLogger, event_name: Union[str, Events]) -> None: if not isinstance(logger, TensorboardLogger): raise RuntimeError("Handler 'WeightsScalarHandler' works only with TensorboardLogger") @@ -419,7 +419,7 @@ class WeightsHistHandler(BaseWeightsHistHandler): def __init__(self, model: nn.Module, tag: Optional[str] = None): super(WeightsHistHandler, self).__init__(model, tag=tag) - def __call__(self, engine: Engine, logger: TensorboardLogger, event_name: Union[str, EventEnum]): + def __call__(self, engine: Engine, logger: TensorboardLogger, event_name: Union[str, Events]) -> None: if not isinstance(logger, TensorboardLogger): raise RuntimeError("Handler 'WeightsHistHandler' works only with TensorboardLogger") @@ -465,10 +465,10 @@ class GradsScalarHandler(BaseWeightsScalarHandler): """ - def __init__(self, model: nn.Module, reduction: Callable = torch.norm, tag: Optional[str] = None): + def __init__(self, model: nn.Module, reduction: Callable = torch.norm, tag: Optional[str] = None) -> None: super(GradsScalarHandler, self).__init__(model, reduction, tag=tag) - def __call__(self, engine: Engine, logger: TensorboardLogger, event_name: Union[str, EventEnum]): + def __call__(self, engine: Engine, logger: TensorboardLogger, event_name: Union[str, Events]) -> None: if not isinstance(logger, TensorboardLogger): raise RuntimeError("Handler 'GradsScalarHandler' works only with TensorboardLogger") @@ -509,10 +509,10 @@ class GradsHistHandler(BaseWeightsHistHandler): """ - def __init__(self, model: nn.Module, tag: Optional[str] = None): + def __init__(self, model: nn.Module, tag: Optional[str] = None) -> None: super(GradsHistHandler, self).__init__(model, tag=tag) - def __call__(self, engine: Engine, logger: TensorboardLogger, event_name: Union[str, EventEnum]): + def __call__(self, engine: Engine, logger: TensorboardLogger, event_name: Union[str, Events]) -> None: if not isinstance(logger, TensorboardLogger): raise RuntimeError("Handler 'GradsHistHandler' works only with TensorboardLogger") diff --git a/ignite/contrib/handlers/time_profilers.py b/ignite/contrib/handlers/time_profilers.py index 81df27be8678..4acb015f65c9 100644 --- a/ignite/contrib/handlers/time_profilers.py +++ b/ignite/contrib/handlers/time_profilers.py @@ -1,6 +1,6 @@ import functools from collections import OrderedDict -from typing import Any, Callable, Dict, Iterable, List, Mapping, Sequence, Union +from typing import Any, Callable, Dict, List, Mapping, Sequence, Tuple, Union, cast import torch @@ -42,14 +42,14 @@ def log_intermediate_results(): Events.DATALOADER_STOP_ITERATION, ] - def __init__(self): + def __init__(self) -> None: self._dataflow_timer = Timer() self._processing_timer = Timer() self._event_handlers_timer = Timer() - self.dataflow_times = None - self.processing_times = None - self.event_handlers_times = None + self.dataflow_times = torch.zeros(1) + self.processing_times = torch.zeros(1) + self.event_handlers_times = {} # type: Dict[EventEnum, torch.Tensor] self._events = [ Events.EPOCH_STARTED, @@ -79,7 +79,7 @@ def __init__(self): self._as_last_completed, ] - def _reset(self, num_epochs: int, total_num_iters: int): + def _reset(self, num_epochs: int, total_num_iters: int) -> None: self.dataflow_times = torch.zeros(total_num_iters) self.processing_times = torch.zeros(total_num_iters) self.event_handlers_times = { @@ -93,13 +93,13 @@ def _reset(self, num_epochs: int, total_num_iters: int): Events.GET_BATCH_STARTED: torch.zeros(total_num_iters), } - def _as_first_started(self, engine: Engine): + def _as_first_started(self, engine: Engine) -> None: if hasattr(engine.state.dataloader, "__len__"): - num_iters_per_epoch = len(engine.state.dataloader) + num_iters_per_epoch = len(engine.state.dataloader) # type: ignore[arg-type] else: - num_iters_per_epoch = engine.state.epoch_length + num_iters_per_epoch = cast(int, engine.state.epoch_length) - self.max_epochs = engine.state.max_epochs + self.max_epochs = cast(int, engine.state.max_epochs) self.total_num_iters = self.max_epochs * num_iters_per_epoch self._reset(self.max_epochs, self.total_num_iters) @@ -125,30 +125,30 @@ def _as_first_started(self, engine: Engine): # Let's go self._event_handlers_timer.reset() - def _as_last_started(self, engine: Engine): + def _as_last_started(self, engine: Engine) -> None: self.event_handlers_times[Events.STARTED][0] = self._event_handlers_timer.value() - def _as_first_epoch_started(self, engine: Engine): + def _as_first_epoch_started(self, engine: Engine) -> None: self._event_handlers_timer.reset() - def _as_last_epoch_started(self, engine: Engine): + def _as_last_epoch_started(self, engine: Engine) -> None: t = self._event_handlers_timer.value() e = engine.state.epoch - 1 self.event_handlers_times[Events.EPOCH_STARTED][e] = t - def _as_first_get_batch_started(self, engine: Engine): + def _as_first_get_batch_started(self, engine: Engine) -> None: self._event_handlers_timer.reset() self._dataflow_timer.reset() - def _as_last_get_batch_started(self, engine: Engine): + def _as_last_get_batch_started(self, engine: Engine) -> None: t = self._event_handlers_timer.value() i = engine.state.iteration - 1 self.event_handlers_times[Events.GET_BATCH_STARTED][i] = t - def _as_first_get_batch_completed(self, engine: Engine): + def _as_first_get_batch_completed(self, engine: Engine) -> None: self._event_handlers_timer.reset() - def _as_last_get_batch_completed(self, engine: Engine): + def _as_last_get_batch_completed(self, engine: Engine) -> None: t = self._event_handlers_timer.value() i = engine.state.iteration - 1 self.event_handlers_times[Events.GET_BATCH_COMPLETED][i] = t @@ -158,40 +158,40 @@ def _as_last_get_batch_completed(self, engine: Engine): self._dataflow_timer.reset() - def _as_first_iter_started(self, engine: Engine): + def _as_first_iter_started(self, engine: Engine) -> None: self._event_handlers_timer.reset() - def _as_last_iter_started(self, engine: Engine): + def _as_last_iter_started(self, engine: Engine) -> None: t = self._event_handlers_timer.value() i = engine.state.iteration - 1 self.event_handlers_times[Events.ITERATION_STARTED][i] = t self._processing_timer.reset() - def _as_first_iter_completed(self, engine: Engine): + def _as_first_iter_completed(self, engine: Engine) -> None: t = self._processing_timer.value() i = engine.state.iteration - 1 self.processing_times[i] = t self._event_handlers_timer.reset() - def _as_last_iter_completed(self, engine: Engine): + def _as_last_iter_completed(self, engine: Engine) -> None: t = self._event_handlers_timer.value() i = engine.state.iteration - 1 self.event_handlers_times[Events.ITERATION_COMPLETED][i] = t - def _as_first_epoch_completed(self, engine: Engine): + def _as_first_epoch_completed(self, engine: Engine) -> None: self._event_handlers_timer.reset() - def _as_last_epoch_completed(self, engine: Engine): + def _as_last_epoch_completed(self, engine: Engine) -> None: t = self._event_handlers_timer.value() e = engine.state.epoch - 1 self.event_handlers_times[Events.EPOCH_COMPLETED][e] = t - def _as_first_completed(self, engine: Engine): + def _as_first_completed(self, engine: Engine) -> None: self._event_handlers_timer.reset() - def _as_last_completed(self, engine: Engine): + def _as_last_completed(self, engine: Engine) -> None: self.event_handlers_times[Events.COMPLETED][0] = self._event_handlers_timer.value() # Remove added handlers: @@ -203,7 +203,7 @@ def _as_last_completed(self, engine: Engine): for e, m in zip(self._events, self._lmethods): engine.remove_event_handler(m, e) - def attach(self, engine: Engine): + def attach(self, engine: Engine) -> None: if not isinstance(engine, Engine): raise TypeError("Argument engine should be ignite.engine.Engine, " "but given {}".format(type(engine))) @@ -211,10 +211,12 @@ def attach(self, engine: Engine): engine._event_handlers[Events.STARTED].insert(0, (self._as_first_started, (engine,), {})) @staticmethod - def _compute_basic_stats(data: Sequence): + def _compute_basic_stats(data: torch.Tensor) -> Dict[str, Union[str, float, Tuple[Union[float], Union[float]]]]: # compute on non-zero data: data = data[data > 0] - out = [("total", torch.sum(data).item() if len(data) > 0 else "not yet triggered")] + out = [ + ("total", torch.sum(data).item() if len(data) > 0 else "not yet triggered") + ] # type: List[Tuple[str, Union[str, float, Tuple[Union[float], Union[float]]]]] if len(data) > 1: out += [ ("min/index", (torch.min(data).item(), torch.argmin(data).item())), @@ -224,7 +226,7 @@ def _compute_basic_stats(data: Sequence): ] return OrderedDict(out) - def get_results(self): + def get_results(self) -> Dict[str, Dict[str, Any]]: """ Method to fetch the aggregated profiler results after the engine is run @@ -233,23 +235,23 @@ def get_results(self): results = profiler.get_results() """ - total_eh_time = sum([(self.event_handlers_times[e]).sum() for e in Events if e not in self.events_to_ignore]) + total_eh_time = sum( + [(self.event_handlers_times[e]).sum() for e in Events if e not in self.events_to_ignore] + ) # type: Union[int, torch.Tensor] + event_handlers_stats = dict( + [ + (str(e.name).replace(".", "_"), self._compute_basic_stats(self.event_handlers_times[e])) + for e in Events + if e not in self.events_to_ignore + ] + + [("total_time", total_eh_time)] # type: ignore[list-item] + ) return OrderedDict( [ ("processing_stats", self._compute_basic_stats(self.processing_times)), ("dataflow_stats", self._compute_basic_stats(self.dataflow_times)), - ( - "event_handlers_stats", - dict( - [ - (str(e.name).replace(".", "_"), self._compute_basic_stats(self.event_handlers_times[e])) - for e in Events - if e not in self.events_to_ignore - ] - + [("total_time", total_eh_time)] - ), - ), + ("event_handlers_stats", event_handlers_stats,), ( "event_handlers_names", {str(e.name).replace(".", "_") + "_names": v for e, v in self.event_handlers_names.items()}, @@ -257,7 +259,7 @@ def get_results(self): ] ) - def write_results(self, output_path: str): + def write_results(self, output_path: str) -> None: """ Method to store the unaggregated profiling results to a csv file @@ -336,7 +338,7 @@ def write_results(self, output_path: str): results_df.to_csv(output_path, index=False) @staticmethod - def print_results(results: Union[Dict, Iterable]): + def print_results(results: Dict) -> str: """ Method to print the aggregated results from the profiler @@ -382,14 +384,14 @@ def print_results(results: Union[Dict, Iterable]): """ - def to_str(v: Union[str, tuple]): + def to_str(v: Union[str, tuple]) -> str: if isinstance(v, str): return v elif isinstance(v, tuple): return "{:.5f}/{}".format(v[0], v[1]) return "{:.5f}".format(v) - def odict_to_str(d: Mapping): + def odict_to_str(d: Mapping) -> str: out = " | ".join([to_str(v) for v in d.values()]) return out @@ -470,21 +472,21 @@ def log_intermediate_results(): EVENT_FILTER_THESHOLD_TIME = 0.0001 - def __init__(self): + def __init__(self) -> None: self._dataflow_timer = Timer() self._processing_timer = Timer() self._event_handlers_timer = Timer() - self.dataflow_times = None - self.processing_times = None - self.event_handlers_times = None + self.dataflow_times = [] # type: List[float] + self.processing_times = [] # type: List[float] + self.event_handlers_times = {} # type: Dict[EventEnum, Dict[str, List[float]]] @staticmethod def _get_callable_name(handler: Callable) -> str: # get name of the callable handler return getattr(handler, "__qualname__", handler.__class__.__name__) - def _create_wrapped_handler(self, handler: Callable, event: Events) -> Callable: + def _create_wrapped_handler(self, handler: Callable, event: EventEnum) -> Callable: @functools.wraps(handler) def _timeit_handler(*args: Any, **kwargs: Any) -> None: self._event_handlers_timer.reset() @@ -583,15 +585,17 @@ def get_results(self) -> List[List[Union[str, float]]]: ) total_eh_time = round(float(total_eh_time), 5,) - def compute_basic_stats(data: Sequence) -> List[Union[str, float]]: - data = torch.tensor(data, dtype=torch.float32) + def compute_basic_stats( + times: Union[Sequence, torch.Tensor] + ) -> List[Union[str, float, Tuple[Union[str, float], Union[str, float]]]]: + data = torch.tensor(times, dtype=torch.float32) # compute on non-zero data: data = data[data > 0] - total = round(torch.sum(data).item(), 5) if len(data) > 0 else "not triggered" - min_index = ("None", "None") - max_index = ("None", "None") - mean = "None" - std = "None" + total = round(torch.sum(data).item(), 5) if len(data) > 0 else "not triggered" # type: Union[str, float] + min_index = ("None", "None") # type: Tuple[Union[str, float], Union[str, float]] + max_index = ("None", "None") # type: Tuple[Union[str, float], Union[str, float]] + mean = "None" # type: Union[str, float] + std = "None" # type: Union[str, float] if len(data) > 0: min_index = (round(torch.min(data).item(), 5), torch.argmin(data).item()) max_index = (round(torch.max(data).item(), 5), torch.argmax(data).item()) @@ -695,8 +699,8 @@ def print_results(results: List[List[Union[str, float]]]) -> None: """ # adopted implementation of torch.autograd.profiler.build_table - handler_column_width = max([len(item[0]) for item in results]) + 4 - event_column_width = max([len(item[1]) for item in results]) + 4 + handler_column_width = max([len(item[0]) for item in results]) + 4 # type: ignore[arg-type] + event_column_width = max([len(item[1]) for item in results]) + 4 # type: ignore[arg-type] DEFAULT_COLUMN_WIDTH = 14 @@ -742,8 +746,8 @@ def append(s: str) -> None: for row in results[:-3]: # format min/idx and max/idx - row[3] = "{}/{}".format(*row[3]) - row[4] = "{}/{}".format(*row[4]) + row[3] = "{}/{}".format(*row[3]) # type: ignore[misc] + row[4] = "{}/{}".format(*row[4]) # type: ignore[misc] append(row_format.format(*row)) @@ -754,8 +758,8 @@ def append(s: str) -> None: summary_format = "{} took total {}s [min/index: {}, max/index: {}, mean: {}s, std: {}s]" for row in results[-2:]: - row[3] = "{}s/{}".format(*row[3]) - row[4] = "{}s/{}".format(*row[4]) + row[3] = "{}s/{}".format(*row[3]) # type: ignore[misc] + row[4] = "{}s/{}".format(*row[4]) # type: ignore[misc] del row[1] append(summary_format.format(*row)) print("".join(result)) diff --git a/ignite/contrib/handlers/tqdm_logger.py b/ignite/contrib/handlers/tqdm_logger.py index a0fe33f561e6..0c5e3e091e22 100644 --- a/ignite/contrib/handlers/tqdm_logger.py +++ b/ignite/contrib/handlers/tqdm_logger.py @@ -1,13 +1,12 @@ # -*- coding: utf-8 -*- import warnings -from enum import Enum -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, List, Optional, Union, cast import torch from ignite.contrib.handlers.base_logger import BaseLogger, BaseOutputHandler from ignite.engine import Engine, Events -from ignite.engine.events import CallableEventWithFilter, EventEnum +from ignite.engine.events import CallableEventWithFilter class ProgressBar(BaseLogger): @@ -106,7 +105,7 @@ def __init__( persist: bool = False, bar_format: str = "{desc}[{n_fmt}/{total_fmt}] {percentage:3.0f}%|{bar}{postfix} [{elapsed}<{remaining}]", **tqdm_kwargs: Any - ): + ) -> None: try: from tqdm.autonotebook import tqdm @@ -122,12 +121,12 @@ def __init__( self.bar_format = bar_format self.tqdm_kwargs = tqdm_kwargs - def _reset(self, pbar_total: int): + def _reset(self, pbar_total: int) -> None: self.pbar = self.pbar_cls( total=pbar_total, leave=self.persist, bar_format=self.bar_format, initial=1, **self.tqdm_kwargs ) - def _close(self, engine: Engine): + def _close(self, engine: Engine) -> None: if self.pbar is not None: # https://github.com/tqdm/notebook.py#L240-L250 # issue #1115 : notebook backend of tqdm checks if n < total (error or KeyboardInterrupt) @@ -138,12 +137,12 @@ def _close(self, engine: Engine): self.pbar = None @staticmethod - def _compare_lt(event1: EventEnum, event2: EventEnum): + def _compare_lt(event1: Events, event2: Events) -> bool: i1 = ProgressBar._events_order.index(event1) i2 = ProgressBar._events_order.index(event2) return i1 < i2 - def log_message(self, message: str): + def log_message(self, message: str) -> None: """ Logs a message, preserving the progress bar correct output format. @@ -154,14 +153,14 @@ def log_message(self, message: str): tqdm.write(message, file=self.tqdm_kwargs.get("file", None)) - def attach( + def attach( # type: ignore[override] self, engine: Engine, metric_names: Optional[str] = None, output_transform: Optional[Callable] = None, - event_name: Union[CallableEventWithFilter, Events] = Events.ITERATION_COMPLETED, + event_name: Events = Events.ITERATION_COMPLETED, closing_event_name: Events = Events.EPOCH_COMPLETED, - ): + ) -> None: """ Attaches the progress bar to an engine object. @@ -200,14 +199,16 @@ def attach( super(ProgressBar, self).attach(engine, log_handler, event_name) engine.add_event_handler(closing_event_name, self._close) - def attach_opt_params_handler(self, engine: Engine, event_name: Union[str, EventEnum], *args: Any, **kwargs: Any): + def attach_opt_params_handler( + self, engine: Engine, event_name: Union[str, Events], *args: Any, **kwargs: Any + ) -> None: """Intentionally empty""" pass - def _create_output_handler(self, *args: Any, **kwargs: Any): + def _create_output_handler(self, *args: Any, **kwargs: Any) -> "_OutputHandler": return _OutputHandler(*args, **kwargs) - def _create_opt_params_handler(self, *args: Any, **kwargs: Any): + def _create_opt_params_handler(self, *args: Any, **kwargs: Any) -> Callable: """Intentionally empty""" pass @@ -232,10 +233,10 @@ class _OutputHandler(BaseOutputHandler): def __init__( self, description: str, - metric_names: Optional[str] = None, + metric_names: Optional[Union[str, List[str]]] = None, output_transform: Optional[Callable] = None, - closing_event_name: EventEnum = Events.EPOCH_COMPLETED, - ): + closing_event_name: Events = Events.EPOCH_COMPLETED, + ) -> None: if metric_names is None and output_transform is None: # This helps to avoid 'Either metric_names or output_transform should be defined' of BaseOutputHandler metric_names = [] @@ -243,14 +244,14 @@ def __init__( self.closing_event_name = closing_event_name @staticmethod - def get_max_number_events(event_name: Union[CallableEventWithFilter, Enum], engine: Engine): + def get_max_number_events(event_name: Union[str, Events], engine: Engine) -> int: if event_name in (Events.ITERATION_STARTED, Events.ITERATION_COMPLETED): - return engine.state.epoch_length + return cast(int, engine.state.epoch_length) if event_name in (Events.EPOCH_STARTED, Events.EPOCH_COMPLETED): - return engine.state.max_epochs + return cast(int, engine.state.max_epochs) return 1 - def __call__(self, engine: Engine, logger: ProgressBar, event_name: Union[CallableEventWithFilter, Enum]): + def __call__(self, engine: Engine, logger: ProgressBar, event_name: Union[str, Events]) -> None: pbar_total = self.get_max_number_events(event_name, engine) if logger.pbar is None: @@ -264,7 +265,7 @@ def __call__(self, engine: Engine, logger: ProgressBar, event_name: Union[Callab if max_num_of_closing_events > 1: global_step = engine.state.get_event_attrib_value(self.closing_event_name) desc += " [{}/{}]".format(global_step, max_num_of_closing_events) - logger.pbar.set_description(desc) + logger.pbar.set_description(desc) # type: ignore[attr-defined] metrics = self._setup_output_metrics(engine) @@ -283,9 +284,9 @@ def __call__(self, engine: Engine, logger: ProgressBar, event_name: Union[Callab rendered_metrics[key] = value if rendered_metrics: - logger.pbar.set_postfix(**rendered_metrics) + logger.pbar.set_postfix(**rendered_metrics) # type: ignore[attr-defined] global_step = engine.state.get_event_attrib_value(event_name) if pbar_total is not None: global_step = (global_step - 1) % pbar_total + 1 - logger.pbar.update(global_step - logger.pbar.n) + logger.pbar.update(global_step - logger.pbar.n) # type: ignore[attr-defined] diff --git a/ignite/contrib/handlers/trains_logger.py b/ignite/contrib/handlers/trains_logger.py index 49bc510542f9..7c00da41c87d 100644 --- a/ignite/contrib/handlers/trains_logger.py +++ b/ignite/contrib/handlers/trains_logger.py @@ -5,7 +5,7 @@ from collections import defaultdict from datetime import datetime from enum import Enum -from typing import Any, Callable, List, Mapping, Optional, Type +from typing import Any, Callable, DefaultDict, List, Mapping, Optional, Tuple, Type, Union import torch from torch.nn import Module @@ -19,7 +19,7 @@ BaseWeightsHistHandler, BaseWeightsScalarHandler, ) -from ignite.engine import Engine, EventEnum +from ignite.engine import Engine, Events from ignite.handlers import global_step_from_engine from ignite.handlers.checkpoint import DiskSaver @@ -119,7 +119,7 @@ class TrainsLogger(BaseLogger): """ - def __init__(self, *_, **kwargs: Any): + def __init__(self, *_: Any, **kwargs: Any) -> None: try: from trains import Task from trains.binding.frameworks.tensorflow_bind import WeightsGradientHistHelper @@ -135,15 +135,15 @@ def __init__(self, *_, **kwargs: Any): warnings.warn("TrainsSaver: running in bypass mode") class _Stub(object): - def __call__(self, *_, **__): + def __call__(self, *_: Any, **__: Any) -> "_Stub": return self - def __getattr__(self, attr): + def __getattr__(self, attr: str) -> "_Stub": if attr in ("name", "id"): - return "" + return "" # type: ignore[return-value] return self - def __setattr__(self, attr, val): + def __setattr__(self, attr: str, val: Any) -> None: pass self._task = _Stub() @@ -183,13 +183,13 @@ def bypass_mode(cls) -> bool: """ return getattr(cls, "_bypass", bool(os.environ.get("CI"))) - def close(self): + def close(self) -> None: self.trains_logger.flush() - def _create_output_handler(self, *args: Any, **kwargs: Any): + def _create_output_handler(self, *args: Any, **kwargs: Any) -> "OutputHandler": return OutputHandler(*args, **kwargs) - def _create_opt_params_handler(self, *args: Any, **kwargs: Any): + def _create_opt_params_handler(self, *args: Any, **kwargs: Any) -> "OptimizerParamsHandler": return OptimizerParamsHandler(*args, **kwargs) @@ -293,17 +293,17 @@ def __init__( metric_names: Optional[List[str]] = None, output_transform: Optional[Callable] = None, global_step_transform: Optional[Callable] = None, - ): + ) -> None: super(OutputHandler, self).__init__(tag, metric_names, output_transform, global_step_transform) - def __call__(self, engine: Engine, logger: TrainsLogger, event_name: EventEnum): + def __call__(self, engine: Engine, logger: TrainsLogger, event_name: Union[str, Events]) -> None: if not isinstance(logger, TrainsLogger): raise RuntimeError("Handler OutputHandler works only with TrainsLogger") metrics = self._setup_output_metrics(engine) - global_step = self.global_step_transform(engine, event_name) + global_step = self.global_step_transform(engine, event_name) # type: ignore[misc] if not isinstance(global_step, int): raise TypeError( @@ -359,10 +359,10 @@ class OptimizerParamsHandler(BaseOptimizerParamsHandler): tag (str, optional): common title for all produced plots. For example, "generator" """ - def __init__(self, optimizer: Optimizer, param_name: str = "lr", tag: Optional[str] = None): + def __init__(self, optimizer: Optimizer, param_name: str = "lr", tag: Optional[str] = None) -> None: super(OptimizerParamsHandler, self).__init__(optimizer, param_name, tag) - def __call__(self, engine: Engine, logger: TrainsLogger, event_name: EventEnum): + def __call__(self, engine: Engine, logger: TrainsLogger, event_name: Union[str, Events]) -> None: if not isinstance(logger, TrainsLogger): raise RuntimeError("Handler OptimizerParamsHandler works only with TrainsLogger") @@ -410,10 +410,10 @@ class WeightsScalarHandler(BaseWeightsScalarHandler): """ - def __init__(self, model: Module, reduction: Callable = torch.norm, tag: Optional[str] = None): + def __init__(self, model: Module, reduction: Callable = torch.norm, tag: Optional[str] = None) -> None: super(WeightsScalarHandler, self).__init__(model, reduction, tag=tag) - def __call__(self, engine: Engine, logger: TrainsLogger, event_name: EventEnum): + def __call__(self, engine: Engine, logger: TrainsLogger, event_name: Union[str, Events]) -> None: if not isinstance(logger, TrainsLogger): raise RuntimeError("Handler WeightsScalarHandler works only with TrainsLogger") @@ -462,10 +462,10 @@ class WeightsHistHandler(BaseWeightsHistHandler): """ - def __init__(self, model: Module, tag: Optional[str] = None): + def __init__(self, model: Module, tag: Optional[str] = None) -> None: super(WeightsHistHandler, self).__init__(model, tag=tag) - def __call__(self, engine: Engine, logger: TrainsLogger, event_name: EventEnum): + def __call__(self, engine: Engine, logger: TrainsLogger, event_name: Union[str, Events]) -> None: if not isinstance(logger, TrainsLogger): raise RuntimeError("Handler 'WeightsHistHandler' works only with TrainsLogger") @@ -517,10 +517,10 @@ class GradsScalarHandler(BaseWeightsScalarHandler): """ - def __init__(self, model: Module, reduction: Callable = torch.norm, tag: Optional[str] = None): + def __init__(self, model: Module, reduction: Callable = torch.norm, tag: Optional[str] = None) -> None: super(GradsScalarHandler, self).__init__(model, reduction, tag=tag) - def __call__(self, engine: Engine, logger: TrainsLogger, event_name: EventEnum): + def __call__(self, engine: Engine, logger: TrainsLogger, event_name: Union[str, Events]) -> None: if not isinstance(logger, TrainsLogger): raise RuntimeError("Handler GradsScalarHandler works only with TrainsLogger") @@ -568,10 +568,10 @@ class GradsHistHandler(BaseWeightsHistHandler): """ - def __init__(self, model: Module, tag: Optional[str] = None): + def __init__(self, model: Module, tag: Optional[str] = None) -> None: super(GradsHistHandler, self).__init__(model, tag=tag) - def __call__(self, engine: Engine, logger: TrainsLogger, event_name: EventEnum): + def __call__(self, engine: Engine, logger: TrainsLogger, event_name: Union[str, Events]) -> None: if not isinstance(logger, TrainsLogger): raise RuntimeError("Handler 'GradsHistHandler' works only with TrainsLogger") @@ -633,8 +633,13 @@ class TrainsSaver(DiskSaver): """ def __init__( - self, logger: TrainsLogger = None, output_uri: str = None, dirname: str = None, *args: Any, **kwargs: Any - ): + self, + logger: Optional[TrainsLogger] = None, + output_uri: Optional[str] = None, + dirname: Optional[str] = None, + *args: Any, + **kwargs: Any + ) -> None: self._setup_check_trains(logger, output_uri) @@ -645,7 +650,7 @@ def __init__( prefix="ignite_checkpoints_{}".format(datetime.now().strftime("%Y_%m_%d_%H_%M_%S_")) ) if idist.get_world_size() > 1: - dirname = idist.all_gather(dirname)[0] + dirname = idist.all_gather(dirname)[0] # type: ignore[index, assignment] warnings.warn("TrainsSaver created a temporary checkpoints directory: {}".format(dirname)) idist.barrier() @@ -654,12 +659,12 @@ def __init__( if "atomic" not in kwargs: kwargs["atomic"] = False - self._checkpoint_slots = defaultdict(list) + self._checkpoint_slots = defaultdict(list) # type: DefaultDict[Union[str, Tuple[str, str]], List[Any]] - super(TrainsSaver, self).__init__(dirname=dirname, *args, **kwargs) + super(TrainsSaver, self).__init__(dirname=dirname, *args, **kwargs) # type: ignore[misc] @idist.one_rank_only() - def _setup_check_trains(self, logger: TrainsLogger, output_uri: str): + def _setup_check_trains(self, logger: TrainsLogger, output_uri: str) -> None: try: from trains import Task except ImportError: @@ -690,7 +695,7 @@ def __init__( filename: str, basename: str, metadata: Optional[Mapping] = None, - ): + ) -> None: self._callback_type = callback_type self._slots = slots self._checkpoint_key = str(checkpoint_key) @@ -698,8 +703,8 @@ def __init__( self._basename = basename self._metadata = metadata - def pre_callback(self, action: str, model_info: Any): - if action != self._callback_type.save: + def pre_callback(self, action: str, model_info: Any) -> Any: + if action != self._callback_type.save: # type: ignore[attr-defined] return model_info try: @@ -713,8 +718,8 @@ def pre_callback(self, action: str, model_info: Any): model_info.local_model_id = "{}:{}".format(self._checkpoint_key, model_info.upload_filename) return model_info - def post_callback(self, action: str, model_info: Any): - if action != self._callback_type.save: + def post_callback(self, action: str, model_info: Any) -> Any: + if action != self._callback_type.save: # type: ignore[attr-defined] return model_info model_info.model.name = "{}: {}".format(model_info.task.name, self._filename) @@ -742,7 +747,7 @@ def __call__(self, checkpoint: Mapping, filename: str, metadata: Optional[Mappin ) try: - basename = metadata["basename"] + basename = metadata["basename"] # type: ignore[index] except (TypeError, KeyError): warnings.warn("Checkpoint metadata missing or basename cannot be found") basename = "checkpoint" @@ -786,6 +791,8 @@ def get_local_copy(self, filename: str) -> Optional[str]: return artifact.get_local_copy() self._task.get_logger().report_text("Can not find artifact {}".format(filename)) + return None + @idist.one_rank_only() def remove(self, filename: str) -> None: super(TrainsSaver, self).remove(filename) diff --git a/ignite/contrib/handlers/visdom_logger.py b/ignite/contrib/handlers/visdom_logger.py index 010ab3b05493..7b737c6bea85 100644 --- a/ignite/contrib/handlers/visdom_logger.py +++ b/ignite/contrib/handlers/visdom_logger.py @@ -1,7 +1,7 @@ import numbers import os import warnings -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union, cast import torch import torch.nn as nn @@ -13,7 +13,7 @@ BaseOutputHandler, BaseWeightsScalarHandler, ) -from ignite.engine import Engine +from ignite.engine import Engine, Events from ignite.handlers import global_step_from_engine __all__ = [ @@ -170,7 +170,7 @@ def __init__( ) if server is None: - server = os.environ.get("VISDOM_SERVER_URL", "localhost") + server = cast(str, os.environ.get("VISDOM_SERVER_URL", "localhost")) if port is None: port = int(os.environ.get("VISDOM_PORT", 8097)) @@ -185,37 +185,39 @@ def __init__( self.vis = visdom.Visdom(server=server, port=port, raise_exceptions=raise_exceptions, **kwargs) - if not self.vis.offline and not self.vis.check_connection(): + if not self.vis.offline and not self.vis.check_connection(): # type: ignore[attr-defined] raise RuntimeError( "Failed to connect to Visdom server at {}. Did you run python -m visdom.server ?".format(server) ) - self.executor = _DummyExecutor() + self.executor = _DummyExecutor() # type: Union[_DummyExecutor, "ThreadPoolExecutor"] if num_workers > 0: from concurrent.futures import ThreadPoolExecutor self.executor = ThreadPoolExecutor(max_workers=num_workers) - def _save(self): - self.vis.save([self.vis.env]) + def _save(self) -> None: + self.vis.save([self.vis.env]) # type: ignore[attr-defined] - def close(self): + def close(self) -> None: self.executor.shutdown() - self.vis = None + self.vis.close() - def _create_output_handler(self, *args: Any, **kwargs: Any): + def _create_output_handler(self, *args: Any, **kwargs: Any) -> "OutputHandler": return OutputHandler(*args, **kwargs) - def _create_opt_params_handler(self, *args: Any, **kwargs: Any): + def _create_opt_params_handler(self, *args: Any, **kwargs: Any) -> "OptimizerParamsHandler": return OptimizerParamsHandler(*args, **kwargs) class _BaseVisDrawer: - def __init__(self, show_legend: bool = False): - self.windows = {} + def __init__(self, show_legend: bool = False) -> None: + self.windows = {} # type: Dict[str, Any] self.show_legend = show_legend - def add_scalar(self, logger: VisdomLogger, k: str, v: Union[str, float], event_name: Any, global_step: int): + def add_scalar( + self, logger: VisdomLogger, k: str, v: Union[str, float, torch.Tensor], event_name: Any, global_step: int + ) -> None: """ Helper method to log a scalar with VisdomLogger. @@ -240,7 +242,7 @@ def add_scalar(self, logger: VisdomLogger, k: str, v: Union[str, float], event_n kwargs = { "X": [global_step], "Y": [v], - "env": logger.vis.env, + "env": logger.vis.env, # type: ignore[attr-defined] "win": self.windows[k]["win"], "update": update, "opts": self.windows[k]["opts"], @@ -346,18 +348,18 @@ def __init__( output_transform: Optional[Callable] = None, global_step_transform: Optional[Callable] = None, show_legend: bool = False, - ): + ) -> None: super(OutputHandler, self).__init__(tag, metric_names, output_transform, global_step_transform) _BaseVisDrawer.__init__(self, show_legend=show_legend) - def __call__(self, engine: Engine, logger: VisdomLogger, event_name: Any): + def __call__(self, engine: Engine, logger: VisdomLogger, event_name: Union[str, Events]) -> None: if not isinstance(logger, VisdomLogger): raise RuntimeError("Handler 'OutputHandler' works only with VisdomLogger") metrics = self._setup_output_metrics(engine) - global_step = self.global_step_transform(engine, event_name) + global_step = self.global_step_transform(engine, event_name) # type: ignore[misc] if not isinstance(global_step, int): raise TypeError( @@ -367,13 +369,13 @@ def __call__(self, engine: Engine, logger: VisdomLogger, event_name: Any): for key, value in metrics.items(): - values = [] + values = [] # type: List[Union[float, torch.Tensor]] keys = [] if isinstance(value, numbers.Number) or isinstance(value, torch.Tensor) and value.ndimension() == 0: - values.append(value) + values.append(value) # type: ignore[arg-type] keys.append(key) elif isinstance(value, torch.Tensor) and value.ndimension() == 1: - values = value + values = value # type: ignore[assignment] keys = ["{}/{}".format(key, i) for i in range(len(value))] else: warnings.warn("VisdomLogger output_handler can not log " "metrics value type {}".format(type(value))) @@ -420,11 +422,11 @@ class OptimizerParamsHandler(BaseOptimizerParamsHandler, _BaseVisDrawer): def __init__( self, optimizer: Optimizer, param_name: str = "lr", tag: Optional[str] = None, show_legend: bool = False, - ): + ) -> None: super(OptimizerParamsHandler, self).__init__(optimizer, param_name, tag) _BaseVisDrawer.__init__(self, show_legend=show_legend) - def __call__(self, engine: Engine, logger: VisdomLogger, event_name: Any): + def __call__(self, engine: Engine, logger: VisdomLogger, event_name: Union[str, Events]) -> None: if not isinstance(logger, VisdomLogger): raise RuntimeError("Handler OptimizerParamsHandler works only with VisdomLogger") @@ -471,11 +473,11 @@ class WeightsScalarHandler(BaseWeightsScalarHandler, _BaseVisDrawer): def __init__( self, model: nn.Module, reduction: Callable = torch.norm, tag: Optional[str] = None, show_legend: bool = False, - ): + ) -> None: super(WeightsScalarHandler, self).__init__(model, reduction, tag=tag) _BaseVisDrawer.__init__(self, show_legend=show_legend) - def __call__(self, engine: Engine, logger: VisdomLogger, event_name: Any): + def __call__(self, engine: Engine, logger: VisdomLogger, event_name: Union[str, Events]) -> None: if not isinstance(logger, VisdomLogger): raise RuntimeError("Handler 'WeightsScalarHandler' works only with VisdomLogger") @@ -522,11 +524,11 @@ class GradsScalarHandler(BaseWeightsScalarHandler, _BaseVisDrawer): def __init__( self, model: nn.Module, reduction: Callable = torch.norm, tag: Optional[str] = None, show_legend: bool = False, - ): + ) -> None: super(GradsScalarHandler, self).__init__(model, reduction, tag) _BaseVisDrawer.__init__(self, show_legend=show_legend) - def __call__(self, engine: Engine, logger: VisdomLogger, event_name: Any): + def __call__(self, engine: Engine, logger: VisdomLogger, event_name: Union[str, Events]) -> None: if not isinstance(logger, VisdomLogger): raise RuntimeError("Handler 'GradsScalarHandler' works only with VisdomLogger") @@ -543,17 +545,17 @@ def __call__(self, engine: Engine, logger: VisdomLogger, event_name: Any): class _DummyExecutor: class _DummyFuture: - def __init__(self, result: Any): + def __init__(self, result: Any) -> None: self._output = result - def result(self): + def result(self) -> Any: return self._output - def __init__(self, *args: Any, **kwargs: Any): + def __init__(self, *args: Any, **kwargs: Any) -> None: pass - def submit(self, fn: Callable, **kwargs: Any): + def submit(self, fn: Callable, **kwargs: Any) -> "_DummyFuture": return _DummyExecutor._DummyFuture(fn(**kwargs)) - def shutdown(self, *args: Any, **kwargs: Any): + def shutdown(self, *args: Any, **kwargs: Any) -> None: pass diff --git a/ignite/contrib/handlers/wandb_logger.py b/ignite/contrib/handlers/wandb_logger.py index de1ac2e1a4a0..a101f4d1f56c 100644 --- a/ignite/contrib/handlers/wandb_logger.py +++ b/ignite/contrib/handlers/wandb_logger.py @@ -1,10 +1,9 @@ -from enum import Enum from typing import Any, Callable, List, Optional, Union from torch.optim import Optimizer from ignite.contrib.handlers.base_logger import BaseLogger, BaseOptimizerParamsHandler, BaseOutputHandler -from ignite.engine import CallableEventWithFilter, Engine +from ignite.engine import Engine, Events from ignite.handlers import global_step_from_engine __all__ = ["WandBLogger", "OutputHandler", "OptimizerParamsHandler", "global_step_from_engine"] @@ -117,7 +116,7 @@ def score_function(engine): evaluator.add_event_handler(Events.COMPLETED, model_checkpoint, {'model': model}) """ - def __init__(self, *args: Any, **kwargs: Any): + def __init__(self, *args: Any, **kwargs: Any) -> None: try: import wandb @@ -130,13 +129,13 @@ def __init__(self, *args: Any, **kwargs: Any): if kwargs.get("init", True): wandb.init(*args, **kwargs) - def __getattr__(self, attr: Any): - return getattr(self._wandb, attr) + def __getattr__(self, attr: Any) -> Any: + return getattr(self._wandb, attr) # type: ignore[misc] - def _create_output_handler(self, *args: Any, **kwargs: Any): + def _create_output_handler(self, *args: Any, **kwargs: Any) -> "OutputHandler": return OutputHandler(*args, **kwargs) - def _create_opt_params_handler(self, *args: Any, **kwargs: Any): + def _create_opt_params_handler(self, *args: Any, **kwargs: Any) -> "OptimizerParamsHandler": return OptimizerParamsHandler(*args, **kwargs) @@ -252,16 +251,16 @@ def __init__( output_transform: Optional[Callable] = None, global_step_transform: Optional[Callable] = None, sync: Optional[bool] = None, - ): + ) -> None: super().__init__(tag, metric_names, output_transform, global_step_transform) self.sync = sync - def __call__(self, engine: Engine, logger: WandBLogger, event_name: Union[CallableEventWithFilter, Enum]): + def __call__(self, engine: Engine, logger: WandBLogger, event_name: Union[str, Events]) -> None: if not isinstance(logger, WandBLogger): raise RuntimeError("Handler '{}' works only with WandBLogger.".format(self.__class__.__name__)) - global_step = self.global_step_transform(engine, event_name) + global_step = self.global_step_transform(engine, event_name) # type: ignore[misc] if not isinstance(global_step, int): raise TypeError( "global_step must be int, got {}." @@ -319,11 +318,11 @@ class OptimizerParamsHandler(BaseOptimizerParamsHandler): def __init__( self, optimizer: Optimizer, param_name: str = "lr", tag: Optional[str] = None, sync: Optional[bool] = None, - ): + ) -> None: super(OptimizerParamsHandler, self).__init__(optimizer, param_name, tag) self.sync = sync - def __call__(self, engine: Engine, logger: WandBLogger, event_name: Union[CallableEventWithFilter, Enum]): + def __call__(self, engine: Engine, logger: WandBLogger, event_name: Union[str, Events]) -> None: if not isinstance(logger, WandBLogger): raise RuntimeError("Handler OptimizerParamsHandler works only with WandBLogger") diff --git a/ignite/engine/events.py b/ignite/engine/events.py index cd8f14da7d62..e7782a2646cd 100644 --- a/ignite/engine/events.py +++ b/ignite/engine/events.py @@ -362,7 +362,7 @@ class State: Events.EPOCH_COMPLETED: "epoch", Events.STARTED: "epoch", Events.COMPLETED: "epoch", - } + } # type: Dict[Union[str, "Events"], str] def __init__(self, **kwargs: Any) -> None: self.iteration = 0 @@ -390,7 +390,7 @@ def _update_attrs(self) -> None: if not hasattr(self, value): setattr(self, value, 0) - def get_event_attrib_value(self, event_name: Events) -> int: + def get_event_attrib_value(self, event_name: Union[str, Events]) -> int: if event_name not in State.event_to_attr: raise RuntimeError("Unknown event name '{}'".format(event_name)) return getattr(self, State.event_to_attr[event_name]) diff --git a/mypy.ini b/mypy.ini index 62b36053b343..64a59e66bf43 100644 --- a/mypy.ini +++ b/mypy.ini @@ -24,21 +24,41 @@ warn_unreachable = False warn_unused_configs = True warn_unused_ignores = True -[mypy-ignite.contrib.handlers.*] +[mypy-horovod.*] +ignore_missing_imports = True -ignore_errors = True +[mypy-matplotlib.*] +ignore_missing_imports = True -[mypy-horovod.*] +[mypy-mlflow.*] +ignore_missing_imports = True + +[mypy-neptune.*] ignore_missing_imports = True [mypy-numpy.*] ignore_missing_imports = True +[mypy-pandas.*] +ignore_missing_imports = True + [mypy-sklearn.*] ignore_missing_imports = True +[mypy-polyaxon_client.*] +ignore_missing_imports = True + [mypy-pynvml.*] ignore_missing_imports = True +[mypy-tensorboardX.*] +ignore_missing_imports = True + [mypy-torch_xla.*] ignore_missing_imports = True + +[mypy-trains.*] +ignore_missing_imports = True + +[mypy-tqdm.*] +ignore_missing_imports = True From dd490793eeaf09b0012505aabd30b6ab969a6810 Mon Sep 17 00:00:00 2001 From: gruebel Date: Thu, 3 Dec 2020 23:00:52 +0100 Subject: [PATCH 2/7] Revert changes in param_scheduler --- ignite/contrib/handlers/param_scheduler.py | 202 +++++++++------------ mypy.ini | 3 + 2 files changed, 89 insertions(+), 116 deletions(-) diff --git a/ignite/contrib/handlers/param_scheduler.py b/ignite/contrib/handlers/param_scheduler.py index f08a5bfdf748..024d477c82ae 100644 --- a/ignite/contrib/handlers/param_scheduler.py +++ b/ignite/contrib/handlers/param_scheduler.py @@ -6,7 +6,7 @@ from collections import OrderedDict from copy import copy from pathlib import Path -from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union, cast +from typing import Any, List, Mapping, Optional, Sequence, Tuple, Union import torch from torch.optim.lr_scheduler import _LRScheduler @@ -40,7 +40,7 @@ def __init__( param_name: str, save_history: bool = False, param_group_index: Optional[int] = None, - ) -> None: + ): if not ( isinstance(optimizer, Optimizer) @@ -54,11 +54,11 @@ def __init__( self.optimizer = optimizer self.param_group_index = param_group_index self.param_name = param_name + self.save_history = save_history self.event_index = 0 - self._save_history = save_history self._state_attrs = ["event_index", "param_name", "save_history", "param_group_index"] - def __call__(self, engine: Optional[Engine], name: Optional[str] = None) -> None: + def __call__(self, engine: Engine, name: Optional[str] = None): value = self.get_param() @@ -79,29 +79,21 @@ def __call__(self, engine: Optional[Engine], name: Optional[str] = None) -> None if name is None: name = self.param_name - if self.save_history and engine: - if not hasattr(engine.state, "param_history") or engine.state.param_history is None: # type: ignore + if self.save_history: + if not hasattr(engine.state, "param_history") or engine.state.param_history is None: setattr(engine.state, "param_history", {}) - engine.state.param_history.setdefault(name, []) # type: ignore[attr-defined] + engine.state.param_history.setdefault(name, []) values = [pg[self.param_name] for pg in self.optimizer_param_groups] - engine.state.param_history[name].append(values) # type: ignore[attr-defined] + engine.state.param_history[name].append(values) self.event_index += 1 @property - def optimizer_param_groups(self) -> List[Dict[str, Any]]: + def optimizer_param_groups(self): if self.param_group_index is None: return self.optimizer.param_groups return [self.optimizer.param_groups[self.param_group_index]] - @property - def save_history(self) -> bool: - return self._save_history - - @save_history.setter - def save_history(self, value: bool) -> None: - self._save_history = value - - def state_dict(self) -> Dict[str, Any]: + def state_dict(self): """Returns a dictionary containing a whole state of ParamScheduler. Returns: @@ -117,7 +109,7 @@ def state_dict(self) -> Dict[str, Any]: destination[name] = copy(val) return destination - def load_state_dict(self, state_dict: Mapping) -> None: + def load_state_dict(self, state_dict: Mapping): """Copies parameters from :attr:`state_dict` into this ParamScheduler. Args: @@ -150,7 +142,7 @@ def get_param(self) -> Union[List[float], float]: pass @classmethod - def simulate_values(cls, num_events: int, **scheduler_kwargs: Any) -> List[List[int]]: + def simulate_values(cls, num_events: int, **scheduler_kwargs: Any): """Method to simulate scheduled values during `num_events` events. Args: @@ -186,7 +178,7 @@ def simulate_values(cls, num_events: int, **scheduler_kwargs: Any) -> List[List[ return values @classmethod - def plot_values(cls, num_events: int, **scheduler_kwargs: Mapping) -> Any: + def plot_values(cls, num_events: int, **scheduler_kwargs: Mapping): """Method to plot simulated scheduled values during `num_events` events. This class requires `matplotlib package `_ to be installed: @@ -293,10 +285,10 @@ def __init__( "end_value_mult", ] - def __call__(self, engine: Optional[Engine], name: Optional[str] = None) -> None: + def __call__(self, engine: Engine, name: Optional[str] = None): if self.event_index != 0 and self.event_index % self.cycle_size == 0: self.event_index = 0 - self.cycle_size = int(self.cycle_size * self.cycle_mult) + self.cycle_size *= self.cycle_mult self.cycle += 1 self.start_value *= self.start_value_mult self.end_value *= self.end_value_mult @@ -343,7 +335,7 @@ class LinearCyclicalScheduler(CyclicalScheduler): # """ - def get_param(self) -> float: + def get_param(self): cycle_progress = self.event_index / self.cycle_size return self.end_value + (self.start_value - self.end_value) * abs(cycle_progress - 0.5) * 2 @@ -409,7 +401,7 @@ class CosineAnnealingScheduler(CyclicalScheduler): Applications of Computer Vision (WACV), 2017 IEEE Winter Conference on. IEEE, 2017 """ - def get_param(self) -> float: + def get_param(self): """Method to get current optimizer's parameter value """ cycle_progress = self.event_index / self.cycle_size @@ -449,7 +441,7 @@ class ConcatScheduler(ParamScheduler): """ - def __init__(self, schedulers: List[ParamScheduler], durations: List[int], save_history: bool = False) -> None: + def __init__(self, schedulers: List[ParamScheduler], durations: List[int], save_history: bool = False): if not isinstance(schedulers, Sequence): raise TypeError("Argument schedulers should be a sequence, but given {}".format(schedulers)) @@ -482,17 +474,17 @@ def __init__(self, schedulers: List[ParamScheduler], durations: List[int], save_ self.schedulers = schedulers self.durations = durations - tmp_optimizers = [s.optimizer for s in self.schedulers] - tmps_list_optimizers = [s if isinstance(s, list) else [s] for s in tmp_optimizers] - param_optimizers = list(itertools.chain(*tmps_list_optimizers)) + param_optimizers = [s.optimizer for s in self.schedulers] + param_optimizers = [s if isinstance(s, list) else [s] for s in param_optimizers] + param_optimizers = list(itertools.chain(*param_optimizers)) optimizer = list(set(param_optimizers)) if len(optimizer) != 1: raise ValueError("schedulers should be related to same optimizer") - tmp_param_names = [s.param_name for s in self.schedulers] - tmp_list_param_names = [s if isinstance(s, list) else [s] for s in tmp_param_names] - param_names = list(itertools.chain(*tmp_list_param_names)) + param_names = [s.param_name for s in self.schedulers] + param_names = [s if isinstance(s, list) else [s] for s in param_names] + param_names = list(itertools.chain(*param_names)) param_name = list(set(param_names)) if len(param_name) != 1: @@ -507,12 +499,12 @@ def __init__(self, schedulers: List[ParamScheduler], durations: List[int], save_ ) self._scheduler_index = 0 - # self._current_scheduler = None - # self._current_duration = None + self._current_scheduler = None + self._current_duration = None self._setup_scheduler() self._state_attrs += ["_current_duration", "durations", "_scheduler_index"] - def state_dict(self) -> Dict[str, Any]: + def state_dict(self): """Returns a dictionary containing a whole state of ConcatScheduler. Returns: @@ -526,7 +518,7 @@ def state_dict(self) -> Dict[str, Any]: state_dict["schedulers"].append(s.state_dict()) return state_dict - def load_state_dict(self, state_dict: Mapping) -> None: + def load_state_dict(self, state_dict: Mapping): """Copies parameters from :attr:`state_dict` into this ConcatScheduler. Args: @@ -553,13 +545,13 @@ def load_state_dict(self, state_dict: Mapping) -> None: super(ConcatScheduler, self).load_state_dict(state_dict) self._setup_scheduler() - def _setup_scheduler(self) -> None: + def _setup_scheduler(self): self._current_scheduler = self.schedulers[self._scheduler_index] self._current_duration = ( self.durations[self._scheduler_index] if self._scheduler_index < len(self.durations) else -1 ) - def __call__(self, engine: Optional[Engine], name: Optional[str] = None) -> None: + def __call__(self, engine: Engine, name: Optional[str] = None): if self._current_duration == 0: self._scheduler_index += 1 self._setup_scheduler() @@ -567,32 +559,32 @@ def __call__(self, engine: Optional[Engine], name: Optional[str] = None) -> None self._current_duration -= 1 @property - def optimizer_param_groups(self) -> List[Dict[str, Any]]: + def optimizer_param_groups(self): # We need to setup optimizer_param_groups as property # to synchonize with the latest _current_scheduler and its internal optimizer_param_groups return self._current_scheduler.optimizer_param_groups @property - def save_history(self) -> bool: + def save_history(self): return self._current_scheduler.save_history @save_history.setter - def save_history(self, value: bool) -> None: + def save_history(self, value: bool): for s in self.schedulers: s.save_history = value - def get_param(self) -> Union[List[float], float]: + def get_param(self): return self._current_scheduler.get_param() @classmethod - def simulate_values( # type: ignore[override] + def simulate_values( cls, num_events: int, schedulers: List[ParamScheduler], durations: List[int], param_names: Optional[Union[List[str], Tuple[str]]] = None, **kwargs: Any - ) -> List[List[int]]: + ): """Method to simulate scheduled values during num_events events. Args: @@ -614,15 +606,15 @@ def simulate_values( # type: ignore[override] "Argument param_names should be list or tuple of strings, but given {}".format(param_names) ) - tmp_param_optimizers = [s.optimizer for s in schedulers] - tmp_list_param_optimizers = [s if isinstance(s, list) else [s] for s in tmp_param_optimizers] - param_optimizers = list(itertools.chain(*tmp_list_param_optimizers)) + param_optimizers = [s.optimizer for s in schedulers] + param_optimizers = [s if isinstance(s, list) else [s] for s in param_optimizers] + param_optimizers = list(itertools.chain(*param_optimizers)) - tmp_optimizer = list(set(param_optimizers)) - if len(tmp_optimizer) != 1: + optimizer = list(set(param_optimizers)) + if len(optimizer) != 1: raise ValueError("schedulers should be related to same optimizer") - optimizer = tmp_optimizer[0] + optimizer = optimizer[0] # This scheduler uses `ParamScheduler` which # should be replicated in order to simulate LR values and @@ -640,9 +632,7 @@ def simulate_values( # type: ignore[override] s.save_history = False output = [] - scheduler = cls( # type: ignore[call-arg] - schedulers=schedulers, save_history=False, durations=durations, **kwargs - ) + scheduler = cls(schedulers=schedulers, save_history=False, durations=durations, **kwargs) if param_names is None: param_names = [scheduler.param_name] for i in range(num_events): @@ -685,7 +675,7 @@ class LRScheduler(ParamScheduler): trainer.add_event_handler(Events.ITERATION_COMPLETED, scheduler) """ - def __init__(self, lr_scheduler: _LRScheduler, save_history: bool = False) -> None: + def __init__(self, lr_scheduler: _LRScheduler, save_history=False): if not isinstance(lr_scheduler, _LRScheduler): raise TypeError( @@ -695,32 +685,28 @@ def __init__(self, lr_scheduler: _LRScheduler, save_history: bool = False) -> No self.lr_scheduler = lr_scheduler super(LRScheduler, self).__init__( - optimizer=self.lr_scheduler.optimizer, # type: ignore[attr-defined] - param_name="lr", - save_history=save_history, + optimizer=self.lr_scheduler.optimizer, param_name="lr", save_history=save_history ) self._state_attrs += ["lr_scheduler"] - def __call__(self, engine: Optional[Engine], name: Optional[str] = None) -> None: - self.lr_scheduler.last_epoch += 1 # type: ignore[attr-defined] + def __call__(self, engine: Engine, name: Optional[str] = None): + self.lr_scheduler.last_epoch += 1 super(LRScheduler, self).__call__(engine, name) def get_param(self) -> Union[float, List[float]]: """Method to get current optimizer's parameter value """ # Emulate context manager for pytorch>=1.4 - self.lr_scheduler._get_lr_called_within_step = True # type: ignore[attr-defined] - lr_list = cast(List[float], self.lr_scheduler.get_lr()) - self.lr_scheduler._get_lr_called_within_step = False # type: ignore[attr-defined] + self.lr_scheduler._get_lr_called_within_step = True + lr_list = self.lr_scheduler.get_lr() + self.lr_scheduler._get_lr_called_within_step = False if len(lr_list) == 1: return lr_list[0] else: return lr_list @classmethod - def simulate_values( # type: ignore[override] - cls, num_events: int, lr_scheduler: _LRScheduler, **kwargs: Any - ) -> List[List[int]]: + def simulate_values(cls, num_events: int, lr_scheduler: _LRScheduler, **kwargs: Any): """Method to simulate scheduled values during num_events events. Args: @@ -743,14 +729,11 @@ def simulate_values( # type: ignore[override] # not perturb original scheduler. with tempfile.TemporaryDirectory() as tmpdirname: cache_filepath = Path(tmpdirname) / "ignite_lr_scheduler_cache.pt" - obj = { - "lr_scheduler": lr_scheduler.state_dict(), - "optimizer": lr_scheduler.optimizer.state_dict(), # type: ignore[attr-defined] - } + obj = {"lr_scheduler": lr_scheduler.state_dict(), "optimizer": lr_scheduler.optimizer.state_dict()} torch.save(obj, cache_filepath.as_posix()) values = [] - scheduler = cls(save_history=False, lr_scheduler=lr_scheduler, **kwargs) # type: ignore[call-arg] + scheduler = cls(save_history=False, lr_scheduler=lr_scheduler, **kwargs) for i in range(num_events): params = [p[scheduler.param_name] for p in scheduler.optimizer_param_groups] values.append([i] + params) @@ -758,7 +741,7 @@ def simulate_values( # type: ignore[override] obj = torch.load(cache_filepath.as_posix()) lr_scheduler.load_state_dict(obj["lr_scheduler"]) - lr_scheduler.optimizer.load_state_dict(obj["optimizer"]) # type: ignore[attr-defined] + lr_scheduler.optimizer.load_state_dict(obj["optimizer"]) return values @@ -770,7 +753,7 @@ def create_lr_scheduler_with_warmup( warmup_end_value: Optional[float] = None, save_history: bool = False, output_simulated_values: Optional[List] = None, -) -> "ConcatScheduler": +): """ Helper method to create a learning rate scheduler with a linear warm-up. @@ -826,9 +809,9 @@ def create_lr_scheduler_with_warmup( if not (warmup_duration > 1): raise ValueError("Argument warmup_duration should be at least 2 events, but given {}".format(warmup_duration)) - warmup_schedulers = [] # type: List[ParamScheduler] + warmup_schedulers = [] - for param_group_index, param_group in enumerate(lr_scheduler.optimizer.param_groups): # type: ignore[union-attr] + for param_group_index, param_group in enumerate(lr_scheduler.optimizer.param_groups): if warmup_end_value is None: param_group_warmup_end_value = param_group["lr"] @@ -853,28 +836,21 @@ def create_lr_scheduler_with_warmup( else: milestones_values.pop(-1) - warmup_schedulers.append( - PiecewiseLinear( - lr_scheduler.optimizer, - param_name="lr", - milestones_values=milestones_values, - param_group_index=param_group_index, - save_history=save_history, - ) + warmup_scheduler = PiecewiseLinear( + lr_scheduler.optimizer, + param_name="lr", + milestones_values=milestones_values, + param_group_index=param_group_index, + save_history=save_history, ) + warmup_schedulers.append(warmup_scheduler) + warmup_scheduler = ParamGroupScheduler(warmup_schedulers, save_history=save_history) - schedulers = [ - warmup_scheduler, - lr_scheduler, - ] # type: List[Union[ParamScheduler, ParamGroupScheduler, _LRScheduler]] + schedulers = [warmup_scheduler, lr_scheduler] durations = [milestones_values[-1][0] + 1] - combined_scheduler = ConcatScheduler( - schedulers, # type: ignore[arg-type] - durations=durations, - save_history=save_history, - ) + combined_scheduler = ConcatScheduler(schedulers, durations=durations, save_history=save_history) if output_simulated_values is not None: if not isinstance(output_simulated_values, list): @@ -883,11 +859,7 @@ def create_lr_scheduler_with_warmup( "but given {}.".format(type(output_simulated_values)) ) num_events = len(output_simulated_values) - result = ConcatScheduler.simulate_values( - num_events=num_events, - schedulers=schedulers, # type: ignore[arg-type] - durations=durations, - ) + result = ConcatScheduler.simulate_values(num_events=num_events, schedulers=schedulers, durations=durations) for i in range(num_events): output_simulated_values[i] = result[i] return combined_scheduler @@ -944,10 +916,10 @@ def __init__( "but given {}".format(milestones_values) ) - values = [] # type: List[float] - milestones = [] # type: List[int] + values = [] + milestones = [] for pair in milestones_values: - if not isinstance(pair, tuple) or len(pair) != 2: + if not isinstance(pair, Sequence) or len(pair) != 2: raise ValueError("Argument milestones_values should be a list of pairs (milestone, param_value)") if not isinstance(pair[0], numbers.Integral): raise TypeError("Value of a milestone should be integer, but given {}".format(type(pair[0]))) @@ -964,7 +936,7 @@ def __init__( self._index = 0 self._state_attrs += ["values", "milestones", "_index"] - def _get_start_end(self) -> Tuple[int, int, float, float]: + def _get_start_end(self): if self.milestones[0] > self.event_index: return self.event_index - 1, self.event_index, self.values[0], self.values[0] elif self.milestones[-1] <= self.event_index: @@ -980,7 +952,7 @@ def _get_start_end(self) -> Tuple[int, int, float, float]: self._index += 1 return self._get_start_end() - def get_param(self) -> float: + def get_param(self): start_index, end_index, start_value, end_value = self._get_start_end() return start_value + (end_value - start_value) * (self.event_index - start_index) / (end_index - start_index) @@ -1046,37 +1018,37 @@ def __init__(self, schedulers: List[ParamScheduler], names: Optional[List[str]] self.optimizer = [s.optimizer for s in self.schedulers] self.param_name = [s.param_name for s in self.schedulers] - def __call__(self, engine: Optional[Engine], name: Optional[str] = None) -> None: + def __call__(self, engine: Engine, name: Optional[str] = None): for scheduler, name in zip(self.schedulers, self.names): scheduler(engine, name) @property - def optimizer_param_groups(self) -> List[Dict[str, Any]]: + def optimizer_param_groups(self): return [pg for scheduler in self.schedulers for pg in scheduler.optimizer_param_groups] @property - def save_history(self) -> bool: + def save_history(self): return self.schedulers[0].save_history @save_history.setter - def save_history(self, value: bool) -> None: + def save_history(self, value: bool): for s in self.schedulers: s.save_history = value - def state_dict(self) -> Dict[str, List[Any]]: + def state_dict(self): """Returns a dictionary containing a whole state of ParamGroupScheduler. Returns: dict: a dictionary containing a whole state of ParamGroupScheduler """ - state_dict = OrderedDict() # type: Dict[str, List[Any]] + state_dict = OrderedDict() state_dict["schedulers"] = [] for n, s in zip(self.names, self.schedulers): state_dict["schedulers"].append((n, s.state_dict())) return state_dict - def load_state_dict(self, state_dict: Mapping) -> None: + def load_state_dict(self, state_dict: Mapping): """Copies parameters from :attr:`state_dict` into this ParamScheduler. Args: @@ -1107,7 +1079,7 @@ def load_state_dict(self, state_dict: Mapping) -> None: s.load_state_dict(sd) @classmethod - def simulate_values(cls, num_events: int, schedulers: List[_LRScheduler], **kwargs: Any) -> List[List[int]]: + def simulate_values(cls, num_events: int, schedulers: _LRScheduler, **kwargs: Any): """Method to simulate scheduled values during num_events events. Args: @@ -1126,28 +1098,26 @@ def simulate_values(cls, num_events: int, schedulers: List[_LRScheduler], **kwar cache_filepath = Path(tmpdirname) / "ignite_lr_scheduler_cache.pt" objs = {"lr_scheduler_{}".format(i): s.state_dict() for i, s in enumerate(schedulers)} # all schedulers should be related to the same optimizer - objs["optimizer"] = schedulers[0].optimizer.state_dict() # type: ignore[attr-defined] + objs["optimizer"] = schedulers[0].optimizer.state_dict() torch.save(objs, cache_filepath.as_posix()) values = [] - scheduler = cls(schedulers=schedulers, **kwargs) # type: ignore[arg-type] + scheduler = cls(schedulers=schedulers, **kwargs) for i in range(num_events): - params = [scheduler.get_param() for scheduler in schedulers] # type: ignore[attr-defined] + params = [scheduler.get_param() for scheduler in schedulers] values.append([i] + params) scheduler(engine=None) objs = torch.load(cache_filepath.as_posix()) for i, s in enumerate(schedulers): s.load_state_dict(objs["lr_scheduler_{}".format(i)]) - s.optimizer.load_state_dict(objs["optimizer"]) # type: ignore[attr-defined] + s.optimizer.load_state_dict(objs["optimizer"]) return values -def _get_fake_optimizer( - optimizer_cls: Optional[Union[Type[Optimizer], Type[torch.optim.SGD]]] = None, **kwargs: Any -) -> Union[Optimizer, torch.optim.SGD]: +def _get_fake_optimizer(optimizer_cls: Optional[Optimizer] = None, **kwargs: Any): t = torch.zeros([1], requires_grad=True) if optimizer_cls is None: optimizer_cls = torch.optim.SGD diff --git a/mypy.ini b/mypy.ini index 64a59e66bf43..24487a2a59e3 100644 --- a/mypy.ini +++ b/mypy.ini @@ -24,6 +24,9 @@ warn_unreachable = False warn_unused_configs = True warn_unused_ignores = True +[mypy-ignite.contrib.handlers.param_scheduler] +ignore_errors = True + [mypy-horovod.*] ignore_missing_imports = True From f55800cfba549e89aa0a028a95c12f7ae610368f Mon Sep 17 00:00:00 2001 From: gruebel Date: Thu, 3 Dec 2020 23:42:36 +0100 Subject: [PATCH 3/7] Fix missed mypy error --- ignite/contrib/handlers/base_logger.py | 6 ++++-- ignite/contrib/handlers/tqdm_logger.py | 24 +++++++++++++----------- ignite/engine/events.py | 4 ++-- 3 files changed, 19 insertions(+), 15 deletions(-) diff --git a/ignite/contrib/handlers/base_logger.py b/ignite/contrib/handlers/base_logger.py index 92a6609360ed..9a602c4397a4 100644 --- a/ignite/contrib/handlers/base_logger.py +++ b/ignite/contrib/handlers/base_logger.py @@ -8,7 +8,7 @@ from torch.optim import Optimizer from ignite.engine import Engine, Events, State -from ignite.engine.events import RemovableEventHandle +from ignite.engine.events import CallableEventWithFilter, RemovableEventHandle class BaseHandler(metaclass=ABCMeta): @@ -148,7 +148,9 @@ class BaseLogger(metaclass=ABCMeta): """ - def attach(self, engine: Engine, log_handler: Callable, event_name: Union[str, Events]) -> RemovableEventHandle: + def attach( + self, engine: Engine, log_handler: Callable, event_name: Union[str, Events, CallableEventWithFilter] + ) -> RemovableEventHandle: """Attach the logger to the engine and execute `log_handler` function at `event_name` events. Args: diff --git a/ignite/contrib/handlers/tqdm_logger.py b/ignite/contrib/handlers/tqdm_logger.py index 0c5e3e091e22..6dc7d22688c0 100644 --- a/ignite/contrib/handlers/tqdm_logger.py +++ b/ignite/contrib/handlers/tqdm_logger.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- import warnings -from typing import Any, Callable, List, Optional, Union, cast +from typing import Any, Callable, List, Optional, Union import torch @@ -98,7 +98,7 @@ class ProgressBar(BaseLogger): Events.ITERATION_COMPLETED, Events.EPOCH_COMPLETED, Events.COMPLETED, - ] + ] # type: List[Union[Events, CallableEventWithFilter]] def __init__( self, @@ -121,7 +121,7 @@ def __init__( self.bar_format = bar_format self.tqdm_kwargs = tqdm_kwargs - def _reset(self, pbar_total: int) -> None: + def _reset(self, pbar_total: Optional[int]) -> None: self.pbar = self.pbar_cls( total=pbar_total, leave=self.persist, bar_format=self.bar_format, initial=1, **self.tqdm_kwargs ) @@ -137,7 +137,9 @@ def _close(self, engine: Engine) -> None: self.pbar = None @staticmethod - def _compare_lt(event1: Events, event2: Events) -> bool: + def _compare_lt( + event1: Union[Events, CallableEventWithFilter], event2: Union[Events, CallableEventWithFilter] + ) -> bool: i1 = ProgressBar._events_order.index(event1) i2 = ProgressBar._events_order.index(event2) return i1 < i2 @@ -158,8 +160,8 @@ def attach( # type: ignore[override] engine: Engine, metric_names: Optional[str] = None, output_transform: Optional[Callable] = None, - event_name: Events = Events.ITERATION_COMPLETED, - closing_event_name: Events = Events.EPOCH_COMPLETED, + event_name: Union[Events, CallableEventWithFilter] = Events.ITERATION_COMPLETED, + closing_event_name: Union[Events, CallableEventWithFilter] = Events.EPOCH_COMPLETED, ) -> None: """ Attaches the progress bar to an engine object. @@ -235,7 +237,7 @@ def __init__( description: str, metric_names: Optional[Union[str, List[str]]] = None, output_transform: Optional[Callable] = None, - closing_event_name: Events = Events.EPOCH_COMPLETED, + closing_event_name: Union[Events, CallableEventWithFilter] = Events.EPOCH_COMPLETED, ) -> None: if metric_names is None and output_transform is None: # This helps to avoid 'Either metric_names or output_transform should be defined' of BaseOutputHandler @@ -244,11 +246,11 @@ def __init__( self.closing_event_name = closing_event_name @staticmethod - def get_max_number_events(event_name: Union[str, Events], engine: Engine) -> int: + def get_max_number_events(event_name: Union[str, Events, CallableEventWithFilter], engine: Engine) -> Optional[int]: if event_name in (Events.ITERATION_STARTED, Events.ITERATION_COMPLETED): - return cast(int, engine.state.epoch_length) + return engine.state.epoch_length if event_name in (Events.EPOCH_STARTED, Events.EPOCH_COMPLETED): - return cast(int, engine.state.max_epochs) + return engine.state.max_epochs return 1 def __call__(self, engine: Engine, logger: ProgressBar, event_name: Union[str, Events]) -> None: @@ -262,7 +264,7 @@ def __call__(self, engine: Engine, logger: ProgressBar, event_name: Union[str, E desc = self.tag or default_desc max_num_of_closing_events = self.get_max_number_events(self.closing_event_name, engine) - if max_num_of_closing_events > 1: + if max_num_of_closing_events and max_num_of_closing_events > 1: global_step = engine.state.get_event_attrib_value(self.closing_event_name) desc += " [{}/{}]".format(global_step, max_num_of_closing_events) logger.pbar.set_description(desc) # type: ignore[attr-defined] diff --git a/ignite/engine/events.py b/ignite/engine/events.py index e7782a2646cd..e05f477b2a82 100644 --- a/ignite/engine/events.py +++ b/ignite/engine/events.py @@ -362,7 +362,7 @@ class State: Events.EPOCH_COMPLETED: "epoch", Events.STARTED: "epoch", Events.COMPLETED: "epoch", - } # type: Dict[Union[str, "Events"], str] + } # type: Dict[Union[str, "Events", "CallableEventWithFilter"], str] def __init__(self, **kwargs: Any) -> None: self.iteration = 0 @@ -390,7 +390,7 @@ def _update_attrs(self) -> None: if not hasattr(self, value): setattr(self, value, 0) - def get_event_attrib_value(self, event_name: Union[str, Events]) -> int: + def get_event_attrib_value(self, event_name: Union[str, Events, CallableEventWithFilter]) -> int: if event_name not in State.event_to_attr: raise RuntimeError("Unknown event name '{}'".format(event_name)) return getattr(self, State.event_to_attr[event_name]) From 4aecd51badd430fb5c654b9874353bbe4788f8f8 Mon Sep 17 00:00:00 2001 From: vfdev Date: Sat, 5 Dec 2020 09:15:23 +0100 Subject: [PATCH 4/7] Update time_profilers.py --- ignite/contrib/handlers/time_profilers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ignite/contrib/handlers/time_profilers.py b/ignite/contrib/handlers/time_profilers.py index 1ed191546a43..f4062eeb049f 100644 --- a/ignite/contrib/handlers/time_profilers.py +++ b/ignite/contrib/handlers/time_profilers.py @@ -588,7 +588,7 @@ def get_results(self) -> List[List[Union[str, float]]]: def compute_basic_stats( times: Union[Sequence, torch.Tensor] ) -> List[Union[str, float, Tuple[Union[str, float], Union[str, float]]]]: - data = torch.as_tensor(data, dtype=torch.float32) + data = torch.as_tensor(times, dtype=torch.float32) # compute on non-zero data: data = data[data > 0] total = round(torch.sum(data).item(), 5) if len(data) > 0 else "not triggered" # type: Union[str, float] From 0476912ba7101ca141a7611879109312bbfb5c8d Mon Sep 17 00:00:00 2001 From: vfdev Date: Sat, 5 Dec 2020 09:42:06 +0100 Subject: [PATCH 5/7] Update time_profilers.py --- ignite/contrib/handlers/time_profilers.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/ignite/contrib/handlers/time_profilers.py b/ignite/contrib/handlers/time_profilers.py index f4062eeb049f..f3708dad15dc 100644 --- a/ignite/contrib/handlers/time_profilers.py +++ b/ignite/contrib/handlers/time_profilers.py @@ -97,6 +97,11 @@ def _as_first_started(self, engine: Engine) -> None: if hasattr(engine.state.dataloader, "__len__"): num_iters_per_epoch = len(engine.state.dataloader) # type: ignore[arg-type] else: + if engine.state.epoch_length is None: + raise ValueError( + "As epoch_length is not set, we can not use BasicTimeProfiler in this case." + "Please, set trainer.run(..., epoch_length=epoch_length) in order to fix this." + ) num_iters_per_epoch = cast(int, engine.state.epoch_length) self.max_epochs = cast(int, engine.state.max_epochs) From 262e3653f2bb5119f89d5f992ec64ae0f6a68e33 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Sat, 5 Dec 2020 08:43:21 +0000 Subject: [PATCH 6/7] autopep8 fix --- ignite/contrib/handlers/time_profilers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ignite/contrib/handlers/time_profilers.py b/ignite/contrib/handlers/time_profilers.py index f3708dad15dc..383a102def50 100644 --- a/ignite/contrib/handlers/time_profilers.py +++ b/ignite/contrib/handlers/time_profilers.py @@ -101,7 +101,7 @@ def _as_first_started(self, engine: Engine) -> None: raise ValueError( "As epoch_length is not set, we can not use BasicTimeProfiler in this case." "Please, set trainer.run(..., epoch_length=epoch_length) in order to fix this." - ) + ) num_iters_per_epoch = cast(int, engine.state.epoch_length) self.max_epochs = cast(int, engine.state.max_epochs) From c1e71d35206da6cdea2cbbf7f55c0eda5e8bbdcb Mon Sep 17 00:00:00 2001 From: vfdev Date: Sat, 5 Dec 2020 10:26:28 +0100 Subject: [PATCH 7/7] Update time_profilers.py --- ignite/contrib/handlers/time_profilers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ignite/contrib/handlers/time_profilers.py b/ignite/contrib/handlers/time_profilers.py index 383a102def50..95e03f55a967 100644 --- a/ignite/contrib/handlers/time_profilers.py +++ b/ignite/contrib/handlers/time_profilers.py @@ -102,7 +102,7 @@ def _as_first_started(self, engine: Engine) -> None: "As epoch_length is not set, we can not use BasicTimeProfiler in this case." "Please, set trainer.run(..., epoch_length=epoch_length) in order to fix this." ) - num_iters_per_epoch = cast(int, engine.state.epoch_length) + num_iters_per_epoch = engine.state.epoch_length self.max_epochs = cast(int, engine.state.max_epochs) self.total_num_iters = self.max_epochs * num_iters_per_epoch