diff --git a/ignite/contrib/engines/common.py b/ignite/contrib/engines/common.py
index 88c637545c7a..78228537c636 100644
--- a/ignite/contrib/engines/common.py
+++ b/ignite/contrib/engines/common.py
@@ -1,7 +1,7 @@
 import numbers
 import warnings
 from functools import partial
-from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Union
+from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Tuple, Union, cast

 import torch
 import torch.nn as nn
@@ -47,7 +47,7 @@ def setup_common_training_handlers(
     clear_cuda_cache: bool = True,
     save_handler: Optional[Union[Callable, BaseSaveHandler]] = None,
     **kwargs: Any
-):
+) -> None:
     """Helper method to setup trainer with common handlers (it also supports distributed configuration):

         - :class:`~ignite.handlers.TerminateOnNan`
@@ -88,24 +88,24 @@ class to use to store ``to_save``. See :class:`~ignite.handlers.checkpoint.Check
         **kwargs: optional keyword args to be passed to construct :class:`~ignite.handlers.checkpoint.Checkpoint`.
     """

-    _kwargs = dict(
-        to_save=to_save,
-        save_every_iters=save_every_iters,
-        output_path=output_path,
-        lr_scheduler=lr_scheduler,
-        with_gpu_stats=with_gpu_stats,
-        output_names=output_names,
-        with_pbars=with_pbars,
-        with_pbar_on_iters=with_pbar_on_iters,
-        log_every_iters=log_every_iters,
-        stop_on_nan=stop_on_nan,
-        clear_cuda_cache=clear_cuda_cache,
-        save_handler=save_handler,
-    )
-    _kwargs.update(kwargs)
-
     if idist.get_world_size() > 1:
-        _setup_common_distrib_training_handlers(trainer, train_sampler=train_sampler, **_kwargs)
+        _setup_common_distrib_training_handlers(
+            trainer,
+            train_sampler=train_sampler,
+            to_save=to_save,
+            save_every_iters=save_every_iters,
+            output_path=output_path,
+            lr_scheduler=lr_scheduler,
+            with_gpu_stats=with_gpu_stats,
+            output_names=output_names,
+            with_pbars=with_pbars,
+            with_pbar_on_iters=with_pbar_on_iters,
+            log_every_iters=log_every_iters,
+            stop_on_nan=stop_on_nan,
+            clear_cuda_cache=clear_cuda_cache,
+            save_handler=save_handler,
+            **kwargs,
+        )
     else:
         if train_sampler is not None and isinstance(train_sampler, DistributedSampler):
             warnings.warn(
@@ -114,7 +114,22 @@ class to use to store ``to_save``. See :class:`~ignite.handlers.checkpoint.Check
                 "Train sampler argument will be ignored",
                 UserWarning,
             )
-        _setup_common_training_handlers(trainer, **_kwargs)
+        _setup_common_training_handlers(
+            trainer,
+            to_save=to_save,
+            save_every_iters=save_every_iters,
+            output_path=output_path,
+            lr_scheduler=lr_scheduler,
+            with_gpu_stats=with_gpu_stats,
+            output_names=output_names,
+            with_pbars=with_pbars,
+            with_pbar_on_iters=with_pbar_on_iters,
+            log_every_iters=log_every_iters,
+            stop_on_nan=stop_on_nan,
+            clear_cuda_cache=clear_cuda_cache,
+            save_handler=save_handler,
+            **kwargs,
+        )


 setup_common_distrib_training_handlers = setup_common_training_handlers
@@ -135,7 +150,7 @@ def _setup_common_training_handlers(
     clear_cuda_cache: bool = True,
     save_handler: Optional[Union[Callable, BaseSaveHandler]] = None,
     **kwargs: Any
-):
+) -> None:
     if output_path is not None and save_handler is not None:
         raise ValueError(
             "Arguments output_path and save_handler are mutually exclusive. Please, define only one of them"
@@ -146,7 +161,9 @@ def _setup_common_training_handlers(

     if lr_scheduler is not None:
         if isinstance(lr_scheduler, torch.optim.lr_scheduler._LRScheduler):
-            trainer.add_event_handler(Events.ITERATION_COMPLETED, lambda engine: lr_scheduler.step())
+            trainer.add_event_handler(
+                Events.ITERATION_COMPLETED, lambda engine: cast(_LRScheduler, lr_scheduler).step()
+            )
         elif isinstance(lr_scheduler, LRScheduler):
             trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler)
         else:
@@ -164,15 +181,19 @@ def _setup_common_training_handlers(
         if output_path is not None:
             save_handler = DiskSaver(dirname=output_path, require_empty=False)

-        checkpoint_handler = Checkpoint(to_save, save_handler, filename_prefix="training", **kwargs)
+        checkpoint_handler = Checkpoint(
+            to_save, cast(Union[Callable, BaseSaveHandler], save_handler), filename_prefix="training", **kwargs
+        )
         trainer.add_event_handler(Events.ITERATION_COMPLETED(every=save_every_iters), checkpoint_handler)

     if with_gpu_stats:
-        GpuInfo().attach(trainer, name="gpu", event_name=Events.ITERATION_COMPLETED(every=log_every_iters))
+        GpuInfo().attach(
+            trainer, name="gpu", event_name=Events.ITERATION_COMPLETED(every=log_every_iters)  # type: ignore[arg-type]
+        )

     if output_names is not None:

-        def output_transform(x, index, name):
+        def output_transform(x: Any, index: int, name: str) -> Any:
             if isinstance(x, Mapping):
                 return x[name]
             elif isinstance(x, Sequence):
@@ -217,7 +238,7 @@ def _setup_common_distrib_training_handlers(
     clear_cuda_cache: bool = True,
     save_handler: Optional[Union[Callable, BaseSaveHandler]] = None,
     **kwargs: Any
-):
+) -> None:

     _setup_common_training_handlers(
         trainer,
@@ -241,18 +262,18 @@ def _setup_common_distrib_training_handlers(
             raise TypeError("Train sampler should be torch DistributedSampler and have `set_epoch` method")

         @trainer.on(Events.EPOCH_STARTED)
-        def distrib_set_epoch(engine):
-            train_sampler.set_epoch(engine.state.epoch - 1)
+        def distrib_set_epoch(engine: Engine) -> None:
+            cast(DistributedSampler, train_sampler).set_epoch(engine.state.epoch - 1)


-def empty_cuda_cache(_):
+def empty_cuda_cache(_: Engine) -> None:
     torch.cuda.empty_cache()
     import gc

     gc.collect()


-def setup_any_logging(logger, logger_module, trainer, optimizers, evaluators, log_every_iters):
+def setup_any_logging(logger, logger_module, trainer, optimizers, evaluators, log_every_iters) -> None:  # type: ignore
     raise DeprecationWarning(
         "ignite.contrib.engines.common.setup_any_logging is deprecated since 0.4.0. and will be remove in 0.6.0. "
         "Please use instead: setup_tb_logging, setup_visdom_logging or setup_mlflow_logging etc."
     )
@@ -262,10 +283,10 @@ def setup_any_logging(logger, logger_module, trainer, optimizers, evaluators, lo
 def _setup_logging(
     logger: BaseLogger,
     trainer: Engine,
-    optimizers: Union[Optimizer, Dict[str, Optimizer]],
-    evaluators: Union[Engine, Dict[str, Engine]],
+    optimizers: Optional[Union[Optimizer, Dict[str, Optimizer], Dict[None, Optimizer]]],
+    evaluators: Optional[Union[Engine, Dict[str, Engine]]],
     log_every_iters: int,
-):
+) -> None:
     if optimizers is not None:
         if not isinstance(optimizers, (Optimizer, Mapping)):
             raise TypeError("Argument optimizers should be either a single optimizer or a dictionary or optimizers")
@@ -311,7 +332,7 @@ def setup_tb_logging(
     evaluators: Optional[Union[Engine, Dict[str, Engine]]] = None,
     log_every_iters: int = 100,
     **kwargs: Any
-):
+) -> TensorboardLogger:
     """Method to setup TensorBoard logging on trainer and a list of evaluators. Logged metrics are:

         - Training metrics, e.g. running average loss values
@@ -343,7 +364,7 @@ def setup_visdom_logging(
     evaluators: Optional[Union[Engine, Dict[str, Engine]]] = None,
     log_every_iters: int = 100,
     **kwargs: Any
-):
+) -> VisdomLogger:
     """Method to setup Visdom logging on trainer and a list of evaluators. Logged metrics are:

         - Training metrics, e.g. running average loss values
@@ -374,7 +395,7 @@ def setup_mlflow_logging(
     evaluators: Optional[Union[Engine, Dict[str, Engine]]] = None,
     log_every_iters: int = 100,
     **kwargs: Any
-):
+) -> MLflowLogger:
     """Method to setup MLflow logging on trainer and a list of evaluators. Logged metrics are:

         - Training metrics, e.g. running average loss values
@@ -405,7 +426,7 @@ def setup_neptune_logging(
     evaluators: Optional[Union[Engine, Dict[str, Engine]]] = None,
     log_every_iters: int = 100,
     **kwargs: Any
-):
+) -> NeptuneLogger:
     """Method to setup Neptune logging on trainer and a list of evaluators. Logged metrics are:

         - Training metrics, e.g. running average loss values
@@ -436,7 +457,7 @@ def setup_wandb_logging(
     evaluators: Optional[Union[Engine, Dict[str, Engine]]] = None,
     log_every_iters: int = 100,
     **kwargs: Any
-):
+) -> WandBLogger:
     """Method to setup WandB logging on trainer and a list of evaluators. Logged metrics are:

         - Training metrics, e.g. running average loss values
@@ -467,7 +488,7 @@ def setup_plx_logging(
     evaluators: Optional[Union[Engine, Dict[str, Engine]]] = None,
     log_every_iters: int = 100,
     **kwargs: Any
-):
+) -> PolyaxonLogger:
     """Method to setup Polyaxon logging on trainer and a list of evaluators. Logged metrics are:

         - Training metrics, e.g. running average loss values
@@ -498,7 +519,7 @@ def setup_trains_logging(
     evaluators: Optional[Union[Engine, Dict[str, Engine]]] = None,
     log_every_iters: int = 100,
     **kwargs: Any
-):
+) -> TrainsLogger:
     """Method to setup Trains logging on trainer and a list of evaluators. Logged metrics are:

         - Training metrics, e.g. running average loss values
@@ -523,8 +544,8 @@ def setup_trains_logging(
     return logger


-def get_default_score_fn(metric_name: str):
-    def wrapper(engine: Engine):
+def get_default_score_fn(metric_name: str) -> Any:
+    def wrapper(engine: Engine) -> Any:
         score = engine.state.metrics[metric_name]
         return score

@@ -540,7 +561,7 @@ def gen_save_best_models_by_val_score(
     trainer: Optional[Engine] = None,
     tag: str = "val",
     **kwargs: Any
-):
+) -> Checkpoint:
     """Method adds a handler to ``evaluator`` to save ``n_saved`` of best models based on the metric
     (named by ``metric_name``) provided by ``evaluator`` (i.e. ``evaluator.state.metrics[metric_name]``).
     Models with highest metric value will be retained. The logic of how to store objects is delegated to
@@ -570,9 +591,10 @@ def gen_save_best_models_by_val_score(
     if trainer is not None:
         global_step_transform = global_step_from_engine(trainer)

-    to_save = models
     if isinstance(models, nn.Module):
-        to_save = {"model": models}
+        to_save = {"model": models}  # type: Dict[str, nn.Module]
+    else:
+        to_save = models

     best_model_handler = Checkpoint(
         to_save,
@@ -598,7 +620,7 @@ def save_best_model_by_val_score(
     trainer: Optional[Engine] = None,
     tag: str = "val",
     **kwargs: Any
-):
+) -> Checkpoint:
     """Method adds a handler to ``evaluator`` to save on a disk ``n_saved`` of best models based on the metric
     (named by ``metric_name``) provided by ``evaluator`` (i.e. ``evaluator.state.metrics[metric_name]``).
     Models with highest metric value will be retained.
@@ -629,7 +651,9 @@ def save_best_model_by_val_score(
     )


-def add_early_stopping_by_val_score(patience: int, evaluator: Engine, trainer: Engine, metric_name: str):
+def add_early_stopping_by_val_score(
+    patience: int, evaluator: Engine, trainer: Engine, metric_name: str
+) -> EarlyStopping:
     """Method setups early stopping handler based on the score (named by `metric_name`) provided by `evaluator`.
     Metric value should increase in order to keep training and not early stop.
diff --git a/ignite/contrib/engines/tbptt.py b/ignite/contrib/engines/tbptt.py
index d6ef72b26e43..d3efba4accbf 100644
--- a/ignite/contrib/engines/tbptt.py
+++ b/ignite/contrib/engines/tbptt.py
@@ -1,4 +1,5 @@
 # coding: utf-8
+import collections.abc as collections
 from typing import Callable, Mapping, Optional, Sequence, Union

 import torch
@@ -20,7 +21,9 @@ class Tbptt_Events(EventEnum):
     TIME_ITERATION_COMPLETED = "time_iteration_completed"


-def _detach_hidden(hidden: Union[torch.Tensor, Sequence, Mapping, str, bytes]):
+def _detach_hidden(
+    hidden: Union[torch.Tensor, Sequence, Mapping, str, bytes]
+) -> Union[torch.Tensor, collections.Sequence, collections.Mapping, str, bytes]:
     """Cut backpropagation graph.

     Auxillary function to cut the backpropagation graph by detaching the hidden
@@ -38,7 +41,7 @@ def create_supervised_tbptt_trainer(
     device: Optional[str] = None,
     non_blocking: bool = False,
     prepare_batch: Callable = _prepare_batch,
-):
+) -> Engine:
     """Create a trainer for truncated backprop through time supervised models.

     Training recurrent model on long sequences is computationally intensive as
@@ -83,7 +86,7 @@ def create_supervised_tbptt_trainer(

     """

-    def _update(engine: Engine, batch: Sequence[torch.Tensor]):
+    def _update(engine: Engine, batch: Sequence[torch.Tensor]) -> float:
         loss_list = []
         hidden = None
diff --git a/ignite/contrib/handlers/tqdm_logger.py b/ignite/contrib/handlers/tqdm_logger.py
index 9e7db89f666f..a0fe33f561e6 100644
--- a/ignite/contrib/handlers/tqdm_logger.py
+++ b/ignite/contrib/handlers/tqdm_logger.py
@@ -159,7 +159,7 @@ def attach(
         engine: Engine,
         metric_names: Optional[str] = None,
         output_transform: Optional[Callable] = None,
-        event_name: Events = Events.ITERATION_COMPLETED,
+        event_name: Union[CallableEventWithFilter, Events] = Events.ITERATION_COMPLETED,
         closing_event_name: Events = Events.EPOCH_COMPLETED,
     ):
         """
diff --git a/mypy.ini b/mypy.ini
index bd790c5ecc8c..62b36053b343 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -28,10 +28,6 @@ warn_unused_ignores = True

 ignore_errors = True

-[mypy-ignite.contrib.engines.*]
-
-ignore_errors = True
-
 [mypy-horovod.*]

 ignore_missing_imports = True
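
For context, a minimal end-to-end sketch of how the helpers touched by this patch are commonly combined; the toy model, output paths, and metric name below are illustrative assumptions, not values taken from the diff.

# Usage sketch (not part of the patch): wiring the now-typed helpers in
# ignite.contrib.engines.common. Paths and the toy model are assumptions.
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from ignite.contrib.engines import common
from ignite.engine import Events, create_supervised_evaluator, create_supervised_trainer
from ignite.metrics import MeanSquaredError

model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = nn.MSELoss()

data = TensorDataset(torch.randn(64, 4), torch.randn(64, 1))
train_loader = DataLoader(data, batch_size=8)
val_loader = DataLoader(data, batch_size=8)

trainer = create_supervised_trainer(model, optimizer, criterion)
evaluator = create_supervised_evaluator(model, metrics={"mse": MeanSquaredError()})


@trainer.on(Events.EPOCH_COMPLETED)
def run_validation(_):
    evaluator.run(val_loader)


# TerminateOnNan, periodic checkpointing and CUDA cache clearing; returns None (see `-> None` above).
common.setup_common_training_handlers(
    trainer,
    to_save={"model": model, "optimizer": optimizer},
    save_every_iters=100,
    output_path="/tmp/ignite-output",  # assumed scratch directory
    with_pbars=False,  # keep the sketch free of the optional tqdm dependency
)

# Returns a TensorboardLogger, matching the new annotation (needs the tensorboard package installed).
tb_logger = common.setup_tb_logging("/tmp/ignite-tb", trainer, optimizers=optimizer, evaluators=evaluator)

# Returns an EarlyStopping handler. The default score is `engine.state.metrics[metric_name]`
# (higher is better), so a loss-like metric such as MSE would need a custom score in real use;
# this line only illustrates the wiring.
common.add_early_stopping_by_val_score(patience=3, evaluator=evaluator, trainer=trainer, metric_name="mse")

trainer.run(train_loader, max_epochs=2)
tb_logger.close()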