diff --git a/ignite/engine/__init__.py b/ignite/engine/__init__.py
index b615a82a80cd..59b46d19d67b 100644
--- a/ignite/engine/__init__.py
+++ b/ignite/engine/__init__.py
@@ -1,3 +1,4 @@
+import collections
 from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
 
 import torch
@@ -27,7 +28,7 @@ def _prepare_batch(
     batch: Sequence[torch.Tensor],
     device: Optional[Union[str, torch.device]] = None,
     non_blocking: bool = False
-):
+) -> Tuple[Union[torch.Tensor, collections.Sequence, collections.Mapping, str, bytes], ...]:
     """Prepare batch for training: pass to a device with options.
 
     """
diff --git a/ignite/engine/deterministic.py b/ignite/engine/deterministic.py
index ede760274e5b..4286b5acc662 100644
--- a/ignite/engine/deterministic.py
+++ b/ignite/engine/deterministic.py
@@ -2,7 +2,7 @@
 import warnings
 from collections import OrderedDict
 from functools import wraps
-from typing import Callable, Generator, Iterator, Optional
+from typing import Any, Callable, Generator, Iterator, List, Optional, cast
 
 import torch
 from torch.utils.data import DataLoader
@@ -61,7 +61,7 @@ def __init__(self, batch_sampler: BatchSampler, start_iteration: Optional[int] =
         if not isinstance(batch_sampler, BatchSampler):
             raise TypeError("Argument batch_sampler should be torch.utils.data.sampler.BatchSampler")
 
-        self.batch_indices = None
+        self.batch_indices = []  # type: List
         self.batch_sampler = batch_sampler
         self.start_iteration = start_iteration
         self.sampler = self.batch_sampler.sampler
@@ -84,7 +84,7 @@ def __len__(self) -> int:
         return len(self.batch_sampler)
 
 
-def _get_rng_states():
+def _get_rng_states() -> List[Any]:
     output = [random.getstate(), torch.get_rng_state()]
     try:
         import numpy as np
@@ -96,7 +96,7 @@ def _get_rng_states():
     return output
 
 
-def _set_rng_states(rng_states):
+def _set_rng_states(rng_states: List[Any]) -> None:
     random.setstate(rng_states[0])
     torch.set_rng_state(rng_states[1])
     try:
@@ -107,14 +107,14 @@ def _set_rng_states(rng_states):
         pass
 
 
-def _repr_rng_state(rng_states):
+def _repr_rng_state(rng_states: List[Any]) -> str:
     from hashlib import md5
 
     out = " ".join([md5(str(list(s)).encode("utf-8")).hexdigest() for s in rng_states])
     return out
 
 
-def keep_random_state(func: Callable):
+def keep_random_state(func: Callable) -> Callable:
     """Helper decorator to keep random state of torch, numpy and random intact
     while executing a function. For more details on usage, please see :ref:`Dataflow synchronization`.
 
@@ -123,7 +123,7 @@ def keep_random_state(func: Callable):
 
     @wraps(func)
-    def wrapper(*args, **kwargs):
+    def wrapper(*args: Any, **kwargs: Any) -> None:
         rng_states = _get_rng_states()
         func(*args, **kwargs)
         _set_rng_states(rng_states)
@@ -181,16 +181,20 @@ def state_dict(self) -> OrderedDict:
         return state_dict
 
     def _init_run(self) -> None:
-        seed = torch.randint(0, int(1e9), (1,)).item()
-        self.state.seed = seed
+        self.state.seed = int(torch.randint(0, int(1e9), (1,)).item())
         if not hasattr(self.state, "rng_states"):
-            self.state.rng_states = None
+            self.state.rng_states = None  # type: ignore[attr-defined]
 
         if torch.cuda.is_available():
             torch.backends.cudnn.deterministic = True
             torch.backends.cudnn.benchmark = False
 
     def _setup_engine(self) -> None:
+        if self.state.dataloader is None:
+            raise RuntimeError(
+                "Internal error, self.state.dataloader is None. Please, file an issue if you encounter this error."
+            )
+
         self._dataloader_len = self._get_data_length(self.state.dataloader)
 
         # if input data is torch dataloader we replace batch sampler by a batch sampler
@@ -199,22 +203,24 @@ def _setup_engine(self) -> None:
             # attribute _dataset_kind is introduced since 1.3.0 => before 1.3.0 all datasets are map-like
             can_patch_dataloader = True
             if hasattr(self.state.dataloader, "_dataset_kind"):
-                from torch.utils.data.dataloader import _DatasetKind
+                from torch.utils.data.dataloader import _DatasetKind  # type: ignore[attr-defined]
 
-                _dataloader_kind = self.state.dataloader._dataset_kind
+                _dataloader_kind = self.state.dataloader._dataset_kind  # type: ignore[attr-defined]
                 can_patch_dataloader = _dataloader_kind == _DatasetKind.Map
             if can_patch_dataloader:
-                if (self._dataloader_len is not None) and hasattr(self.state.dataloader.sampler, "epoch"):
+                if self._dataloader_len is not None and hasattr(
+                    self.state.dataloader.sampler, "epoch"  # type: ignore[attr-defined]
+                ):
                     if self._dataloader_len != self.state.epoch_length:
                         warnings.warn(
                             "When defined engine's epoch length is different of input dataloader length, "
                             "distributed sampler indices can not be setup in a reproducible manner"
                         )
 
-                batch_sampler = self.state.dataloader.batch_sampler
+                batch_sampler = self.state.dataloader.batch_sampler  # type: ignore[attr-defined]
                 if not (batch_sampler is None or isinstance(batch_sampler, ReproducibleBatchSampler)):
                     self.state.dataloader = update_dataloader(
-                        self.state.dataloader, ReproducibleBatchSampler(batch_sampler)
+                        self.state.dataloader, ReproducibleBatchSampler(batch_sampler)  # type: ignore[arg-type]
                     )
 
         iteration = self.state.iteration
@@ -228,20 +234,24 @@ def _setup_engine(self) -> None:
         # restore rng state if in the middle
         in_the_middle = self.state.iteration % self._dataloader_len > 0 if self._dataloader_len is not None else False
         if (getattr(self.state, "rng_states", None) is not None) and in_the_middle:
-            _set_rng_states(self.state.rng_states)
-            self.state.rng_states = None
+            _set_rng_states(self.state.rng_states)  # type: ignore[attr-defined]
+            self.state.rng_states = None  # type: ignore[attr-defined]
 
     def _from_iteration(self, iteration: int) -> Iterator:
+        if self.state.dataloader is None:
+            raise RuntimeError(
+                "Internal error, self.state.dataloader is None. Please, file an issue if you encounter this error."
+            )
         data = self.state.dataloader
         if isinstance(data, DataLoader):
             try:
                 # following is unsafe for IterableDatasets
-                iteration %= len(data.batch_sampler)
+                iteration %= len(data.batch_sampler)  # type: ignore[attr-defined, arg-type]
                 # Synchronize dataflow according to state.iteration
                 self._setup_seed()
                 if iteration > 0:
                     # batch sampler is ReproducibleBatchSampler
-                    data.batch_sampler.start_iteration = iteration
+                    data.batch_sampler.start_iteration = iteration  # type: ignore[attr-defined, union-attr]
                 return iter(data)
             except TypeError as e:
                 # Probably we can do nothing with DataLoader built upon IterableDatasets
@@ -249,7 +259,7 @@ def _from_iteration(self, iteration: int) -> Iterator:
         self.logger.info("Resuming from iteration for provided data will fetch data until required iteration ...")
 
         if hasattr(data, "__len__"):
-            iteration %= len(data)
+            iteration %= len(data)  # type: ignore[arg-type]
         # Synchronize dataflow from the begining
         self._setup_seed(iteration=0)
         data_iter = iter(data)
@@ -263,11 +273,11 @@ def _from_iteration(self, iteration: int) -> Iterator:
 
         return data_iter
 
-    def _setup_seed(self, _=None, iter_counter=None, iteration=None):
+    def _setup_seed(self, _: Any = None, iter_counter: Optional[int] = None, iteration: Optional[int] = None) -> None:
         if iter_counter is None:
             le = self._dataloader_len if self._dataloader_len is not None else 1
         else:
             le = iter_counter
         if iteration is None:
             iteration = self.state.iteration
-        manual_seed(self.state.seed + iteration // le)
+        manual_seed(self.state.seed + iteration // le)  # type: ignore[operator]
diff --git a/ignite/engine/engine.py b/ignite/engine/engine.py
index c9583c566880..3e28f8761e01 100644
--- a/ignite/engine/engine.py
+++ b/ignite/engine/engine.py
@@ -5,7 +5,9 @@
 import weakref
 from collections import OrderedDict, defaultdict
 from collections.abc import Mapping
-from typing import Any, Callable, Iterable, List, Optional, Union
+from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple, Union
+
+from torch.utils.data import DataLoader
 
 from ignite._utils import _to_hours_mins_secs
 from ignite.base import Serializable
@@ -120,18 +122,18 @@ def compute_mean_std(engine, batch):
     _state_dict_one_of_opt_keys = ("iteration", "epoch")
 
     def __init__(self, process_function: Callable):
-        self._event_handlers = defaultdict(list)
+        self._event_handlers = defaultdict(list)  # type: Dict[Any, List]
         self.logger = logging.getLogger(__name__ + "." + self.__class__.__name__)
         self._process_function = process_function
-        self.last_event_name = None
+        self.last_event_name = None  # type: Optional[Events]
         self.should_terminate = False
         self.should_terminate_single_epoch = False
         self.state = State()
-        self._state_dict_user_keys = []
-        self._allowed_events = []
+        self._state_dict_user_keys = []  # type: List[str]
+        self._allowed_events = []  # type: List[EventEnum]
 
-        self._dataloader_iter = None
-        self._init_iter = []
+        self._dataloader_iter = None  # type: Optional[Iterator[Any]]
+        self._init_iter = []  # type: List[int]
 
         self.register_events(*Events)
 
@@ -232,16 +234,16 @@ def _handler_wrapper(self, handler: Callable, event_name: Any, event_filter: Cal
         # signature of the following wrapper will be inspected during registering to check if engine is necessary
         # we have to build a wrapper with relevant signature : solution is functools.wraps
         @functools.wraps(handler)
-        def wrapper(*args, **kwargs) -> Any:
+        def wrapper(*args: Any, **kwargs: Any) -> Any:
             event = self.state.get_event_attrib_value(event_name)
             if event_filter(self, event):
                 return handler(*args, **kwargs)
 
         # setup input handler as parent to make has_event_handler work
-        wrapper._parent = weakref.ref(handler)
+        wrapper._parent = weakref.ref(handler)  # type: ignore[attr-defined]
         return wrapper
 
-    def add_event_handler(self, event_name: Any, handler: Callable, *args, **kwargs):
+    def add_event_handler(self, event_name: Any, handler: Callable, *args: Any, **kwargs: Any) -> RemovableEventHandle:
         """Add an event handler to be executed when the specified event is fired.
 
         Args:
@@ -312,7 +314,7 @@ def execute_something():
         return RemovableEventHandle(event_name, handler, self)
 
     @staticmethod
-    def _assert_non_filtered_event(event_name: Any):
+    def _assert_non_filtered_event(event_name: Any) -> None:
         if (
             isinstance(event_name, CallableEventWithFilter)
             and event_name.filter != CallableEventWithFilter.default_event_filter
@@ -321,7 +323,7 @@ def _assert_non_filtered_event(event_name: Any):
                 "Argument event_name should not be a filtered event, "
                 "please use event without any event filtering"
             )
 
-    def has_event_handler(self, handler: Callable, event_name: Optional[Any] = None):
+    def has_event_handler(self, handler: Callable, event_name: Optional[Any] = None) -> bool:
         """Check if the specified event has the specified handler.
 
         Args:
@@ -332,7 +334,7 @@ def has_event_handler(self, handler: Callable, event_name: Optional[Any] = None)
         if event_name is not None:
             if event_name not in self._event_handlers:
                 return False
-            events = [event_name]
+            events = [event_name]  # type: Union[List[Any], Dict[Any, List]]
         else:
             events = self._event_handlers
         for e in events:
@@ -344,10 +346,10 @@ def has_event_handler(self, handler: Callable, event_name: Optional[Any] = None)
     @staticmethod
     def _compare_handlers(user_handler: Callable, registered_handler: Callable) -> bool:
         if hasattr(registered_handler, "_parent"):
-            registered_handler = registered_handler._parent()
+            registered_handler = registered_handler._parent()  # type: ignore[attr-defined]
         return registered_handler == user_handler
 
-    def remove_event_handler(self, handler: Callable, event_name: Any):
+    def remove_event_handler(self, handler: Callable, event_name: Any) -> None:
         """Remove event handler `handler` from registered handlers of the engine
 
         Args:
@@ -367,7 +369,7 @@ def remove_event_handler(self, handler: Callable, event_name: Any):
             raise ValueError("Input handler '{}' is not found among registered event handlers".format(handler))
         self._event_handlers[event_name] = new_event_handlers
 
-    def on(self, event_name, *args, **kwargs):
+    def on(self, event_name: Any, *args: Any, **kwargs: Any) -> Callable:
         """Decorator shortcut for add_event_handler.
 
         Args:
@@ -398,7 +400,7 @@ def decorator(f: Callable) -> Callable:
 
         return decorator
 
-    def _fire_event(self, event_name: Any, *event_args, **event_kwargs) -> None:
+    def _fire_event(self, event_name: Any, *event_args: Any, **event_kwargs: Any) -> None:
         """Execute all the handlers associated with given event.
 
         This method executes all handlers associated with the event
@@ -460,7 +462,7 @@ def terminate_epoch(self) -> None:
         )
         self.should_terminate_single_epoch = True
 
-    def _handle_exception(self, e: Exception) -> None:
+    def _handle_exception(self, e: BaseException) -> None:
         if Events.EXCEPTION_RAISED in self._event_handlers:
             self._fire_event(Events.EXCEPTION_RAISED, e)
         else:
@@ -497,7 +499,7 @@ def save_engine(_):
             a dictionary containing engine's state
         """
 
-        keys = self._state_dict_all_req_keys + (self._state_dict_one_of_opt_keys[0],)
+        keys = self._state_dict_all_req_keys + (self._state_dict_one_of_opt_keys[0],)  # type: Tuple[str, ...]
        keys += tuple(self._state_dict_user_keys)
        return OrderedDict([(k, getattr(self.state, k)) for k in keys])
 
@@ -555,9 +557,9 @@ def load_state_dict(self, state_dict: Mapping) -> None:
 
     @staticmethod
     def _is_done(state: State) -> bool:
-        return state.iteration == state.epoch_length * state.max_epochs
+        return state.iteration == state.epoch_length * state.max_epochs  # type: ignore[operator]
 
-    def set_data(self, data):
+    def set_data(self, data: Union[Iterable, DataLoader]) -> None:
         """Method to set data. After calling the method the next batch passed to `processing_function` is
         from newly provided data. Please, note that epoch length is not modified.
 
@@ -705,21 +707,25 @@ def switch_batch(engine):
         return self._internal_run()
 
     @staticmethod
-    def _init_timers(state: State):
+    def _init_timers(state: State) -> None:
         state.times[Events.EPOCH_COMPLETED.name] = 0.0
         state.times[Events.COMPLETED.name] = 0.0
 
-    def _get_data_length(self, data):
-        data_length = None
+    def _get_data_length(self, data: Iterable) -> Optional[int]:
         try:
             if hasattr(data, "__len__"):
-                data_length = len(data)
+                return len(data)  # type: ignore[arg-type]
         except TypeError:
             # _InfiniteConstantSampler can raise a TypeError on DataLoader length of a IterableDataset
             pass
-        return data_length
+        return None
 
     def _setup_engine(self) -> None:
+        if self.state.dataloader is None:
+            raise RuntimeError(
+                "Internal error, self.state.dataloader is None. Please, file an issue if you encounter this error."
+            )
+
         iteration = self.state.iteration
         self._dataloader_iter = iter(self.state.dataloader)
 
@@ -734,7 +740,7 @@ def _internal_run(self) -> State:
         try:
             start_time = time.time()
             self._fire_event(Events.STARTED)
-            while self.state.epoch < self.state.max_epochs and not self.should_terminate:
+            while self.state.epoch < self.state.max_epochs and not self.should_terminate:  # type: ignore[operator]
                 self.state.epoch += 1
                 self._fire_event(Events.EPOCH_STARTED)
 
@@ -785,6 +791,15 @@ def _run_once_on_dataset(self) -> float:
         iter_counter = self._init_iter.pop() if len(self._init_iter) > 0 else 0
         should_exit = False
         try:
+            if self._dataloader_iter is None:
+                raise RuntimeError(
+                    "Internal error, self._dataloader_iter is None. Please, file an issue if you encounter this error."
+                )
+            if self.state.dataloader is None:
+                raise RuntimeError(
+                    "Internal error, self.state.dataloader is None. Please, file an issue if you encounter this error."
+                )
+
             while True:
                 try:
                     # Avoid Events.GET_BATCH_STARTED triggered twice when data iter is restarted
@@ -808,7 +823,8 @@ def _run_once_on_dataset(self) -> float:
                         "Data iterator can not provide data anymore but required total number of "
                         "iterations to run is not reached. "
" "Current iteration: {} vs Total iterations to run : {}".format( - self.state.iteration, self.state.epoch_length * self.state.max_epochs + self.state.iteration, + self.state.epoch_length * self.state.max_epochs, # type: ignore[operator] ) ) break diff --git a/ignite/engine/events.py b/ignite/engine/events.py index 818c491770e6..6ef19b7b4246 100644 --- a/ignite/engine/events.py +++ b/ignite/engine/events.py @@ -3,10 +3,15 @@ import weakref from enum import Enum from types import DynamicClassAttribute -from typing import Callable, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Iterator, List, Optional, Union + +from torch.utils.data import DataLoader from ignite.engine.utils import _check_signature +if TYPE_CHECKING: + from ignite.engine.engine import Engine + __all__ = ["CallableEventWithFilter", "EventEnum", "Events", "State"] @@ -23,7 +28,7 @@ class CallableEventWithFilter: """ - def __init__(self, value: str, event_filter: Optional[Callable] = None, name=None): + def __init__(self, value: str, event_filter: Optional[Callable] = None, name: Optional[str] = None) -> None: if event_filter is None: event_filter = CallableEventWithFilter.default_event_filter self.filter = event_filter @@ -36,12 +41,12 @@ def __init__(self, value: str, event_filter: Optional[Callable] = None, name=Non # copied to be compatible to enum @DynamicClassAttribute - def name(self): + def name(self) -> str: """The name of the Enum member.""" return self._name_ @DynamicClassAttribute - def value(self): + def value(self) -> str: """The value of the Enum member.""" return self._value_ @@ -92,7 +97,7 @@ def __call__( @staticmethod def every_event_filter(every: int) -> Callable: - def wrapper(engine, event: int) -> bool: + def wrapper(engine: "Engine", event: int) -> bool: if event % every == 0: return True return False @@ -101,7 +106,7 @@ def wrapper(engine, event: int) -> bool: @staticmethod def once_event_filter(once: int) -> Callable: - def wrapper(engine, event: int) -> bool: + def wrapper(engine: "Engine", event: int) -> bool: if event == once: return True return False @@ -109,13 +114,13 @@ def wrapper(engine, event: int) -> bool: return wrapper @staticmethod - def default_event_filter(engine, event: int) -> bool: + def default_event_filter(engine: "Engine", event: int) -> bool: return True def __str__(self) -> str: return "" % (self.name, self.filter) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if isinstance(other, CallableEventWithFilter): return self.name == other.name elif isinstance(other, str): @@ -123,16 +128,16 @@ def __eq__(self, other): else: return NotImplemented - def __hash__(self): + def __hash__(self) -> int: return hash(self._name_) - def __or__(self, other): + def __or__(self, other: Any) -> "EventsList": return EventsList() | self | other class CallableEvents(CallableEventWithFilter): # For backward compatibility - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: super(CallableEvents, self).__init__(*args, **kwargs) warnings.warn( "Class ignite.engine.events.CallableEvents is deprecated. It will be removed in 0.5.0. " @@ -141,7 +146,7 @@ def __init__(self, *args, **kwargs): ) -class EventEnum(CallableEventWithFilter, Enum): +class EventEnum(CallableEventWithFilter, Enum): # type: ignore[misc] """Base class for all :class:`~ignite.engine.events.Events`. User defined custom events should also inherit this class. 
     For example, Custom events based on the loss calculation and backward pass can be created as follows:
@@ -288,7 +293,7 @@ class CustomEvents(EventEnum):
     TERMINATE = "terminate"
     TERMINATE_SINGLE_EPOCH = "terminate_single_epoch"
 
-    def __or__(self, other):
+    def __or__(self, other: Any) -> "EventsList":
         return EventsList() | self | other
 
@@ -316,24 +321,24 @@ def call_on_events(engine):
 
     """
 
-    def __init__(self):
-        self._events = []
+    def __init__(self) -> None:
+        self._events = []  # type: List[Union[Events, CallableEventWithFilter]]
 
-    def _append(self, event: Union[Events, CallableEventWithFilter]):
+    def _append(self, event: Union[Events, CallableEventWithFilter]) -> None:
         if not isinstance(event, (Events, CallableEventWithFilter)):
             raise TypeError("Argument event should be Events or CallableEventWithFilter, got: {}".format(type(event)))
         self._events.append(event)
 
-    def __getitem__(self, item):
+    def __getitem__(self, item: int) -> Union[Events, CallableEventWithFilter]:
         return self._events[item]
 
-    def __iter__(self):
+    def __iter__(self) -> Iterator[Union[Events, CallableEventWithFilter]]:
         return iter(self._events)
 
-    def __len__(self):
+    def __len__(self) -> int:
         return len(self._events)
 
-    def __or__(self, other: Union[Events, CallableEventWithFilter]):
+    def __or__(self, other: Union[Events, CallableEventWithFilter]) -> "EventsList":
         self._append(event=other)
         return self
 
@@ -369,29 +374,32 @@ class State:
         Events.COMPLETED: "epoch",
     }
 
-    def __init__(self, **kwargs):
+    def __init__(self, **kwargs: Any) -> None:
         self.iteration = 0
         self.epoch = 0
-        self.epoch_length = None
-        self.max_epochs = None
-        self.output = None
-        self.batch = None
-        self.metrics = {}
-        self.dataloader = None
-        self.seed = None
-        self.times = {Events.EPOCH_COMPLETED.name: None, Events.COMPLETED.name: None}
+        self.epoch_length = None  # type: Optional[int]
+        self.max_epochs = None  # type: Optional[int]
+        self.output = None  # type: Optional[int]
+        self.batch = None  # type: Optional[int]
+        self.metrics = {}  # type: Dict[str, Any]
+        self.dataloader = None  # type: Optional[Union[DataLoader, Iterable[Any]]]
+        self.seed = None  # type: Optional[int]
+        self.times = {
+            Events.EPOCH_COMPLETED.name: None,
+            Events.COMPLETED.name: None,
+        }  # type: Dict[str, Optional[float]]
 
         for k, v in kwargs.items():
             setattr(self, k, v)
 
         self._update_attrs()
 
-    def _update_attrs(self):
+    def _update_attrs(self) -> None:
         for value in self.event_to_attr.values():
             if not hasattr(self, value):
                 setattr(self, value, 0)
 
-    def get_event_attrib_value(self, event_name: Union[CallableEventWithFilter, Enum]) -> int:
+    def get_event_attrib_value(self, event_name: Events) -> int:
         if event_name not in State.event_to_attr:
             raise RuntimeError("Unknown event name '{}'".format(event_name))
         return getattr(self, State.event_to_attr[event_name])
@@ -434,7 +442,9 @@ def print_epoch(engine):
         # print_epoch handler is now unregistered
     """
 
-    def __init__(self, event_name: Union[CallableEventWithFilter, Enum, EventsList], handler: Callable, engine):
+    def __init__(
+        self, event_name: Union[CallableEventWithFilter, Enum, EventsList, Events], handler: Callable, engine: "Engine"
+    ) -> None:
         self.event_name = event_name
         self.handler = weakref.ref(handler)
         self.engine = weakref.ref(engine)
@@ -455,8 +465,8 @@ def remove(self) -> None:
         if engine.has_event_handler(handler, self.event_name):
             engine.remove_event_handler(handler, self.event_name)
 
-    def __enter__(self):
+    def __enter__(self) -> "RemovableEventHandle":
         return self
 
-    def __exit__(self, *args, **kwargs) -> None:
+    def __exit__(self, *args: Any, **kwargs: Any) -> None:
         self.remove()
diff --git a/ignite/engine/utils.py b/ignite/engine/utils.py
index a09c0f3f6466..9c4c5b8d9846 100644
--- a/ignite/engine/utils.py
+++ b/ignite/engine/utils.py
@@ -1,11 +1,11 @@
 import inspect
-from typing import Callable
+from typing import Any, Callable
 
 
-def _check_signature(fn: Callable, fn_description: str, *args, **kwargs) -> None:
+def _check_signature(fn: Callable, fn_description: str, *args: Any, **kwargs: Any) -> None:
     # if handler with filter, check the handler rather than the decorator
     if hasattr(fn, "_parent"):
-        signature = inspect.signature(fn._parent())
+        signature = inspect.signature(fn._parent())  # type: ignore[attr-defined]
     else:
         signature = inspect.signature(fn)
     try:  # try without engine
diff --git a/ignite/handlers/__init__.py b/ignite/handlers/__init__.py
index 9278d48dd795..973e1297fdca 100644
--- a/ignite/handlers/__init__.py
+++ b/ignite/handlers/__init__.py
@@ -1,7 +1,7 @@
 from typing import Any, Callable, Union
 
 from ignite.engine import Engine
-from ignite.engine.events import CallableEventWithFilter, EventEnum
+from ignite.engine.events import Events
 from ignite.handlers.checkpoint import Checkpoint, DiskSaver, ModelCheckpoint
 from ignite.handlers.early_stopping import EarlyStopping
 from ignite.handlers.terminate_on_nan import TerminateOnNan
@@ -30,7 +30,7 @@ def global_step_from_engine(engine: Engine, custom_event_name=None) -> Callable:
         global step
     """
 
-    def wrapper(_: Any, event_name: Union[EventEnum, CallableEventWithFilter]):
+    def wrapper(_: Any, event_name: Events):
         if custom_event_name is not None:
             event_name = custom_event_name
         return engine.state.get_event_attrib_value(event_name)
diff --git a/mypy.ini b/mypy.ini
index 2778b2e7490e..33b53407df0c 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -7,10 +7,6 @@ show_error_codes = True
 
 ignore_errors = True
 
-[mypy-ignite.engine.*]
-
-ignore_errors = True
-
 [mypy-ignite.metrics.*]
 
 ignore_errors = True
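
Editor's note (not part of the diff above): the mypy.ini hunk removes the [mypy-ignite.engine.*] ignore_errors override, so ignite.engine is now checked by the repository's mypy configuration. Throughout the diff, annotations are written as comment-style hints (# type: ...) with targeted # type: ignore[...] pragmas rather than PEP 526 variable annotations. A minimal sketch of how mypy reads these follows; the class and attribute names here are illustrative only and are not taken from ignite.

from typing import List, Optional


class ExampleState:
    def __init__(self) -> None:
        # Comment-style hint: mypy treats this like `self.max_epochs: Optional[int] = None`,
        # while the comment is simply ignored at runtime.
        self.max_epochs = None  # type: Optional[int]
        self.keys = []  # type: List[str]

    def total_iterations(self, epoch_length: int) -> int:
        # Without this narrowing check, mypy reports multiplying Optional[int] by int
        # as an [operator] error; the diff silences comparable spots with
        # `# type: ignore[operator]` instead of restructuring the code.
        if self.max_epochs is None:
            raise RuntimeError("max_epochs is not set")
        return self.max_epochs * epoch_length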