33 changes: 18 additions & 15 deletions ignite/contrib/handlers/base_logger.py
@@ -1,18 +1,19 @@
import numbers
import warnings
from abc import ABCMeta, abstractmethod
from typing import Any, Callable, List, Optional, Sequence, Union
from typing import Any, Callable, Dict, List, Optional, Sequence, Union

import torch
import torch.nn as nn
from torch.optim import Optimizer

from ignite.engine import Engine, State
from ignite.engine import Engine, Events, State
from ignite.engine.events import CallableEventWithFilter, RemovableEventHandle


class BaseHandler(metaclass=ABCMeta):
@abstractmethod
def __call__(self, engine, logger, event_name):
def __call__(self, engine: Engine, logger: Any, event_name: Union[str, Events]) -> None:
pass


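As a quick orientation for the new signature, here is a minimal sketch of a concrete handler written against the annotated abstract `__call__`; the `PrintIterationHandler` name is purely illustrative and not part of this PR:

```python
from typing import Any, Union

from ignite.contrib.handlers.base_logger import BaseHandler
from ignite.engine import Engine, Events


class PrintIterationHandler(BaseHandler):
    """Hypothetical handler, shown only to illustrate the annotated signature."""

    def __call__(self, engine: Engine, logger: Any, event_name: Union[str, Events]) -> None:
        # Mirrors the typed abstract method: engine, logger and event name in; nothing out.
        print("iteration:", engine.state.iteration)
```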
@@ -68,15 +69,15 @@ def __init__(

if global_step_transform is None:

def global_step_transform(engine, event_name):
def global_step_transform(engine: Engine, event_name: Union[str, Events]) -> int:
return engine.state.get_event_attrib_value(event_name)

self.tag = tag
self.metric_names = metric_names
self.output_transform = output_transform
self.global_step_transform = global_step_transform

def _setup_output_metrics(self, engine: Engine):
def _setup_output_metrics(self, engine: Engine) -> Dict[str, Any]:
"""Helper method to setup metrics to log
"""
metrics = {}
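For reference, a user-supplied `global_step_transform` should now match the annotated default above, i.e. `(Engine, Union[str, Events]) -> int`. A hedged sketch (the function name is illustrative, not part of this PR):

```python
from typing import Union

from ignite.engine import Engine, Events


def global_step_from_iteration(engine: Engine, event_name: Union[str, Events]) -> int:
    # Same typed contract as the default: return an int step for the given event.
    return engine.state.iteration
```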
@@ -108,14 +109,14 @@ class BaseWeightsScalarHandler(BaseHandler):
Helper handler to log model's weights as scalars.
"""

def __init__(self, model: nn.Module, reduction: Callable = torch.norm, tag: Optional[str] = None):
def __init__(self, model: nn.Module, reduction: Callable = torch.norm, tag: Optional[str] = None) -> None:
if not isinstance(model, torch.nn.Module):
raise TypeError("Argument model should be of type torch.nn.Module, " "but given {}".format(type(model)))

if not callable(reduction):
raise TypeError("Argument reduction should be callable, " "but given {}".format(type(reduction)))

def _is_0D_tensor(t: torch.Tensor):
def _is_0D_tensor(t: torch.Tensor) -> bool:
return isinstance(t, torch.Tensor) and t.ndimension() == 0

# Test reduction function on a tensor
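Since `reduction` must be a callable that maps a tensor to a 0-D tensor, a custom reduction that would pass the `_is_0D_tensor` check could look like this sketch (the name is illustrative):

```python
import torch


def max_abs(t: torch.Tensor) -> torch.Tensor:
    # abs().max() collapses the tensor to a 0-D tensor, satisfying _is_0D_tensor.
    return t.abs().max()


assert max_abs(torch.randn(3, 4)).ndimension() == 0
```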
@@ -147,7 +148,9 @@ class BaseLogger(metaclass=ABCMeta):

"""

def attach(self, engine: Engine, log_handler: Callable, event_name: Any):
def attach(
self, engine: Engine, log_handler: Callable, event_name: Union[str, Events, CallableEventWithFilter]
) -> RemovableEventHandle:
"""Attach the logger to the engine and execute `log_handler` function at `event_name` events.

Args:
@@ -167,7 +170,7 @@ def attach(self, engine: Engine, log_handler: Callable, event_name: Any):

return engine.add_event_handler(event_name, log_handler, self, name)

def attach_output_handler(self, engine: Engine, event_name: Any, *args: Any, **kwargs: Any):
def attach_output_handler(self, engine: Engine, event_name: Any, *args: Any, **kwargs: Any) -> RemovableEventHandle:
"""Shortcut method to attach `OutputHandler` to the logger.

Args:
@@ -183,7 +186,7 @@ def attach_output_handler(self, engine: Engine, event_name: Any, *args: Any, **k
"""
return self.attach(engine, self._create_output_handler(*args, **kwargs), event_name=event_name)

def attach_opt_params_handler(self, engine: Engine, event_name: Any, *args: Any, **kwargs: Any):
def attach_opt_params_handler(self, engine: Engine, event_name: Any, *args: Any, **kwargs: Any) -> None:
"""Shortcut method to attach `OptimizerParamsHandler` to the logger.

Args:
@@ -200,18 +203,18 @@ def attach_opt_params_handler(self, engine: Engine, event_name: Any, *args: Any,
self.attach(engine, self._create_opt_params_handler(*args, **kwargs), event_name=event_name)

@abstractmethod
def _create_output_handler(self, engine: Engine, *args: Any, **kwargs: Any):
def _create_output_handler(self, engine: Engine, *args: Any, **kwargs: Any) -> Callable:
pass

@abstractmethod
def _create_opt_params_handler(self, *args: Any, **kwargs: Any):
def _create_opt_params_handler(self, *args: Any, **kwargs: Any) -> Callable:
pass

def __enter__(self):
def __enter__(self) -> "BaseLogger":
return self

def __exit__(self, type, value, traceback):
def __exit__(self, type: Any, value: Any, traceback: Any) -> None:
self.close()

def close(self):
def close(self) -> None:
pass
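To show how the annotated context-manager protocol and the `RemovableEventHandle` return type are meant to be consumed, a hedged usage sketch (assumes `mlflow` is installed; the dummy trainer is illustrative only):

```python
from ignite.contrib.handlers.mlflow_logger import MLflowLogger
from ignite.engine import Engine, Events

trainer = Engine(lambda engine, batch: batch)  # placeholder process function

with MLflowLogger() as logger:  # __enter__ is now typed to return "BaseLogger"
    handle = logger.attach_output_handler(
        trainer,
        event_name=Events.ITERATION_COMPLETED,
        tag="training",
        output_transform=lambda out: {"loss": out},
    )
    # attach_output_handler now advertises RemovableEventHandle, so the
    # handler can be detached explicitly when no longer needed.
    handle.remove()
```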
48 changes: 24 additions & 24 deletions ignite/contrib/handlers/lr_finder.py
@@ -4,7 +4,7 @@
import tempfile
import warnings
from pathlib import Path
from typing import Callable, Mapping, Optional
from typing import Any, Callable, Dict, List, Mapping, Optional, Union

import torch
from torch.optim import Optimizer
@@ -71,11 +71,11 @@ class FastaiLRFinder:
fastai/lr_find: https://github.com/fastai/fastai
"""

def __init__(self):
def __init__(self) -> None:
self._diverge_flag = False
self._history = None
self._history = {} # type: Dict[str, List[Any]]
self._best_loss = None
self._lr_schedule = None
self._lr_schedule = None # type: Optional[Union[LRScheduler, PiecewiseLinear]]
self.logger = logging.getLogger(__name__)

def _run(
@@ -88,7 +88,7 @@ def _run(
step_mode: str,
smooth_f: float,
diverge_th: float,
):
) -> None:

self._history = {"lr": [], "loss": []}
self._best_loss = None
@@ -98,7 +98,7 @@
if num_iter is None:
num_iter = trainer.state.epoch_length * trainer.state.max_epochs
else:
max_iter = trainer.state.epoch_length * trainer.state.max_epochs
max_iter = trainer.state.epoch_length * trainer.state.max_epochs # type: ignore[operator]
if num_iter > max_iter:
warnings.warn(
"Desired num_iter {} is unreachable with the current run setup of {} iteration "
@@ -127,16 +127,16 @@ def _run(
if not trainer.has_event_handler(self._lr_schedule):
trainer.add_event_handler(Events.ITERATION_COMPLETED, self._lr_schedule, num_iter)

def _reset(self, trainer: Engine):
def _reset(self, trainer: Engine) -> None:
self.logger.debug("Completed LR finder run")
trainer.remove_event_handler(self._lr_schedule, Events.ITERATION_COMPLETED)
trainer.remove_event_handler(self._lr_schedule, Events.ITERATION_COMPLETED) # type: ignore[arg-type]
trainer.remove_event_handler(self._log_lr_and_loss, Events.ITERATION_COMPLETED)
trainer.remove_event_handler(self._reached_num_iterations, Events.ITERATION_COMPLETED)

def _log_lr_and_loss(self, trainer: Engine, output_transform: Callable, smooth_f: float, diverge_th: float):
def _log_lr_and_loss(self, trainer: Engine, output_transform: Callable, smooth_f: float, diverge_th: float) -> None:
output = trainer.state.output
loss = output_transform(output)
lr = self._lr_schedule.get_param()
lr = self._lr_schedule.get_param() # type: ignore[union-attr]
self._history["lr"].append(lr)
if trainer.state.iteration == 1:
self._best_loss = loss
@@ -148,24 +148,24 @@ def _log_lr_and_loss(self, trainer: Engine, output_transform: Callable, smooth_f
self._history["loss"].append(loss)

# Check if the loss has diverged; if it has, stop the trainer
if self._history["loss"][-1] > diverge_th * self._best_loss:
if self._history["loss"][-1] > diverge_th * self._best_loss: # type: ignore[operator]
self._diverge_flag = True
self.logger.info("Stopping early, the loss has diverged")
trainer.terminate()

def _reached_num_iterations(self, trainer: Engine, num_iter: int):
def _reached_num_iterations(self, trainer: Engine, num_iter: int) -> None:
if trainer.state.iteration > num_iter:
trainer.terminate()

def _warning(self, _):
def _warning(self, _: Any) -> None:
if not self._diverge_flag:
warnings.warn(
"Run completed without loss diverging, increase end_lr, decrease diverge_th or look"
" at lr_finder.plot()",
UserWarning,
)

def _detach(self, trainer: Engine):
def _detach(self, trainer: Engine) -> None:
"""
Detaches lr_finder from trainer.

@@ -180,13 +180,13 @@ def _detach(self, trainer: Engine):
if trainer.has_event_handler(self._reset, Events.COMPLETED):
trainer.remove_event_handler(self._reset, Events.COMPLETED)

def get_results(self):
def get_results(self) -> Dict[str, List[Any]]:
"""
Returns: dictionary with loss and lr logs from the previous run
"""
return self._history

def plot(self, skip_start: int = 10, skip_end: int = 5, log_lr: bool = True):
def plot(self, skip_start: int = 10, skip_end: int = 5, log_lr: bool = True) -> None:
"""Plots the learning rate range test.

This method requires `matplotlib` package to be installed:
@@ -211,7 +211,7 @@ def plot(self, skip_start: int = 10, skip_end: int = 5, log_lr: bool = True):
"Please install it with command: \n pip install matplotlib"
)

if self._history is None:
if not self._history:
raise RuntimeError("learning rate finder didn't run yet so results can't be plotted")

if skip_start < 0:
@@ -239,11 +239,11 @@ def plot(self, skip_start: int = 10, skip_end: int = 5, log_lr: bool = True):
plt.ylabel("Loss")
plt.show()

def lr_suggestion(self):
def lr_suggestion(self) -> Any:
"""
Returns: learning rate at the minimum numerical gradient
"""
if self._history is None:
if not self._history:
raise RuntimeError("learning rate finder didn't run yet so lr_suggestion can't be returned")
loss = self._history["loss"]
grads = torch.tensor([loss[i] - loss[i - 1] for i in range(1, len(loss))])
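The minimum-numerical-gradient rule behind `lr_suggestion` can be re-computed in isolation; this toy snippet is a sketch of the idea, not the library code itself:

```python
import torch

# Toy history, purely illustrative.
history = {
    "lr": [1e-5, 1e-4, 1e-3, 1e-2, 1e-1],
    "loss": [1.00, 0.95, 0.60, 0.55, 2.00],
}

loss = history["loss"]
# Same finite differences as in the hunk above.
grads = torch.tensor([loss[i] - loss[i - 1] for i in range(1, len(loss))])
# The suggested lr is where the loss fell fastest; the index is shifted by one
# because the differences start at the second point.
print(history["lr"][int(grads.argmin()) + 1])  # 0.001 for this toy curve
```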
@@ -261,7 +261,7 @@ def attach(
step_mode: str = "exp",
smooth_f: float = 0.05,
diverge_th: float = 5.0,
):
) -> Any:
"""Attaches lr_finder to a given trainer. It also resets model and optimizer at the end of the run.

Usage:
@@ -372,12 +372,12 @@ class _ExponentialLR(_LRScheduler):

"""

def __init__(self, optimizer: Optimizer, end_lr: float, num_iter: int, last_epoch: int = -1):
def __init__(self, optimizer: Optimizer, end_lr: float, num_iter: int, last_epoch: int = -1) -> None:
self.end_lr = end_lr
self.num_iter = num_iter
super(_ExponentialLR, self).__init__(optimizer, last_epoch)

def get_lr(self):
curr_iter = self.last_epoch + 1
def get_lr(self) -> List[float]: # type: ignore
curr_iter = self.last_epoch + 1 # type: ignore[attr-defined]
r = curr_iter / self.num_iter
return [base_lr * (self.end_lr / base_lr) ** r for base_lr in self.base_lrs]
return [base_lr * (self.end_lr / base_lr) ** r for base_lr in self.base_lrs] # type: ignore[attr-defined]
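The annotated `get_lr` implements a plain exponential ramp from each base learning rate to `end_lr` over `num_iter` steps. A standalone re-computation of that formula, assuming a single parameter group with base_lr=1e-4, end_lr=1.0 and num_iter=100 (values are illustrative):

```python
base_lr, end_lr, num_iter = 1e-4, 1.0, 100

lrs = []
for curr_iter in range(1, num_iter + 1):
    r = curr_iter / num_iter
    lrs.append(base_lr * (end_lr / base_lr) ** r)

print(lrs[0], lrs[-1])  # starts just above base_lr and ends exactly at end_lr
```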
28 changes: 14 additions & 14 deletions ignite/contrib/handlers/mlflow_logger.py
@@ -1,12 +1,12 @@
import numbers
import warnings
from typing import Any, Callable, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Union

import torch
from torch.optim import Optimizer

from ignite.contrib.handlers.base_logger import BaseLogger, BaseOptimizerParamsHandler, BaseOutputHandler
from ignite.engine import Engine, EventEnum
from ignite.engine import Engine, Events
from ignite.handlers import global_step_from_engine

__all__ = ["MLflowLogger", "OutputHandler", "OptimizerParamsHandler", "global_step_from_engine"]
@@ -86,7 +86,7 @@ class MLflowLogger(BaseLogger):
)
"""

def __init__(self, tracking_uri: Optional[str] = None):
def __init__(self, tracking_uri: Optional[str] = None) -> None:
try:
import mlflow
except ImportError:
@@ -102,21 +102,21 @@ def __init__(self, tracking_uri: Optional[str] = None):
if self.active_run is None:
self.active_run = mlflow.start_run()

def __getattr__(self, attr: Any):
def __getattr__(self, attr: Any) -> Any:

import mlflow

return getattr(mlflow, attr)

def close(self):
def close(self) -> None:
import mlflow

mlflow.end_run()

def _create_output_handler(self, *args: Any, **kwargs: Any):
def _create_output_handler(self, *args: Any, **kwargs: Any) -> "OutputHandler":
return OutputHandler(*args, **kwargs)

def _create_opt_params_handler(self, *args: Any, **kwargs: Any):
def _create_opt_params_handler(self, *args: Any, **kwargs: Any) -> "OptimizerParamsHandler":
return OptimizerParamsHandler(*args, **kwargs)


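For orientation, a hedged sketch of the delegation pattern above: unknown attributes fall through `__getattr__` to the `mlflow` module, and `close()` ends the active run. Assumes `mlflow` is installed; the tracking URI is illustrative:

```python
from ignite.contrib.handlers.mlflow_logger import MLflowLogger

logger = MLflowLogger(tracking_uri="file:./mlruns")

# Not defined on the logger itself: __getattr__ forwards this to mlflow.log_param.
logger.log_param("seed", 42)

# Ends the active MLflow run via mlflow.end_run().
logger.close()
```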
@@ -212,28 +212,28 @@ def __init__(
metric_names: Optional[Union[str, List[str]]] = None,
output_transform: Optional[Callable] = None,
global_step_transform: Optional[Callable] = None,
):
) -> None:
super(OutputHandler, self).__init__(tag, metric_names, output_transform, global_step_transform)

def __call__(self, engine: Engine, logger: MLflowLogger, event_name: Union[str, EventEnum]):
def __call__(self, engine: Engine, logger: MLflowLogger, event_name: Union[str, Events]) -> None:

if not isinstance(logger, MLflowLogger):
raise TypeError("Handler 'OutputHandler' works only with MLflowLogger")

metrics = self._setup_output_metrics(engine)

global_step = self.global_step_transform(engine, event_name)
global_step = self.global_step_transform(engine, event_name) # type: ignore[misc]

if not isinstance(global_step, int):
raise TypeError(
"global_step must be int, got {}."
" Please check the output of global_step_transform.".format(type(global_step))
)

rendered_metrics = {}
rendered_metrics = {} # type: Dict[str, float]
for key, value in metrics.items():
if isinstance(value, numbers.Number):
rendered_metrics["{} {}".format(self.tag, key)] = value
rendered_metrics["{} {}".format(self.tag, key)] = value # type: ignore[assignment]
elif isinstance(value, torch.Tensor) and value.ndimension() == 0:
rendered_metrics["{} {}".format(self.tag, key)] = value.item()
elif isinstance(value, torch.Tensor) and value.ndimension() == 1:
@@ -290,10 +290,10 @@ class OptimizerParamsHandler(BaseOptimizerParamsHandler):
tag (str, optional): common title for all produced plots. For example, 'generator'
"""

def __init__(self, optimizer: Optimizer, param_name: str = "lr", tag: Optional[str] = None):
def __init__(self, optimizer: Optimizer, param_name: str = "lr", tag: Optional[str] = None) -> None:
super(OptimizerParamsHandler, self).__init__(optimizer, param_name, tag)

def __call__(self, engine: Engine, logger: MLflowLogger, event_name: Union[str, EventEnum]):
def __call__(self, engine: Engine, logger: MLflowLogger, event_name: Union[str, Events]) -> None:
if not isinstance(logger, MLflowLogger):
raise TypeError("Handler OptimizerParamsHandler works only with MLflowLogger")

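Finally, a hedged usage sketch of the annotated `OptimizerParamsHandler.__call__` via the `attach_opt_params_handler` shortcut; the dummy model, optimizer and trainer are illustrative only and assume `mlflow` is installed:

```python
import torch
from torch.optim import SGD

from ignite.contrib.handlers.mlflow_logger import MLflowLogger
from ignite.engine import Engine, Events

model = torch.nn.Linear(2, 1)
optimizer = SGD(model.parameters(), lr=0.01)
trainer = Engine(lambda engine, batch: batch)  # placeholder process function

mlflow_logger = MLflowLogger()
# Logs the optimizer's "lr" under the "generator" tag at every iteration start.
mlflow_logger.attach_opt_params_handler(
    trainer,
    event_name=Events.ITERATION_STARTED,
    optimizer=optimizer,
    param_name="lr",
    tag="generator",
)
```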