From 5093b9533f030eebd0036781d97d0cef58bbd7e8 Mon Sep 17 00:00:00 2001 From: gruebel Date: Tue, 22 Dec 2020 13:11:44 +0100 Subject: [PATCH 1/2] Add None check for max_epochs --- ignite/contrib/handlers/param_scheduler.py | 3 ++- ignite/engine/engine.py | 12 +++++++++--- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/ignite/contrib/handlers/param_scheduler.py b/ignite/contrib/handlers/param_scheduler.py index 11e2b97f8e44..2e9ae5894a53 100644 --- a/ignite/contrib/handlers/param_scheduler.py +++ b/ignite/contrib/handlers/param_scheduler.py @@ -65,7 +65,8 @@ def __call__(self, engine: Optional[Engine], name: Optional[str] = None) -> None if isinstance(value, list): if len(value) != len(self.optimizer_param_groups): raise ValueError( - f"size of value is different than optimizer_param_groups {len(value)} != {len(self.optimizer_param_groups)}" + "size of value is different than optimizer_param_groups " + f"{len(value)} != {len(self.optimizer_param_groups)}" ) for i, param_group in enumerate(self.optimizer_param_groups): diff --git a/ignite/engine/engine.py b/ignite/engine/engine.py index 4b343465a7b0..9b93af56acf3 100644 --- a/ignite/engine/engine.py +++ b/ignite/engine/engine.py @@ -561,7 +561,8 @@ def _is_done(state: State) -> bool: is_done_iters = state.max_iters is not None and state.iteration >= state.max_iters is_done_count = ( state.epoch_length is not None - and state.iteration >= state.epoch_length * state.max_epochs # type: ignore[operator] + and state.max_epochs is not None + and state.iteration >= state.epoch_length * state.max_epochs ) is_done_epochs = state.max_epochs is not None and state.epoch >= state.max_epochs return is_done_iters or is_done_count or is_done_epochs @@ -833,12 +834,17 @@ def _run_once_on_dataset(self) -> float: # Should exit while loop if we can not iterate if should_exit: if not self._is_done(self.state): + total_iters = ( + self.state.epoch_length * self.state.max_epochs + if self.state.max_epochs is not None + else None + ) + warnings.warn( "Data iterator can not provide data anymore but required total number of " "iterations to run is not reached. " "Current iteration: {} vs Total iterations to run : {}".format( - self.state.iteration, - self.state.epoch_length * self.state.max_epochs, # type: ignore[operator] + self.state.iteration, total_iters, ) ) break From 0a87d54946bdc5cda89c57a7b83df2447276f41c Mon Sep 17 00:00:00 2001 From: gruebel Date: Tue, 22 Dec 2020 14:30:41 +0100 Subject: [PATCH 2/2] Small CR fix --- ignite/engine/engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ignite/engine/engine.py b/ignite/engine/engine.py index 9b93af56acf3..2171618463f9 100644 --- a/ignite/engine/engine.py +++ b/ignite/engine/engine.py @@ -837,7 +837,7 @@ def _run_once_on_dataset(self) -> float: total_iters = ( self.state.epoch_length * self.state.max_epochs if self.state.max_epochs is not None - else None + else self.state.max_iters ) warnings.warn(