Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 19 additions & 16 deletions ignite/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,11 +552,12 @@ def run(
data (Iterable): Collection of batches allowing repeated iteration (e.g., list or `DataLoader`).
max_epochs (int, optional): Max epochs to run for (default: None).
If a new state should be created (first run or run again from ended engine), its default value is 1.
This argument should be `None` if run is resuming from a state.
If run is resuming from a state, provided `max_epochs` will be taken into account and should be larger
than `engine.state.max_epochs`.
epoch_length (int, optional): Number of iterations to count as one epoch. By default, it can be set as
`len(data)`. If `data` is an iterator and `epoch_length` is not set, then it will be automatically
determined as the iteration on which data iterator raises `StopIteration`.
This argument should be `None` if run is resuming from a state.
This argument should not change if run is resuming from a state.
seed (int, optional): Deprecated argument. Please, use `torch.manual_seed` or
:meth:`~ignite.utils.manual_seed`.

Expand All @@ -582,20 +583,8 @@ def switch_batch(engine):
"Please, use torch.manual_seed or ignite.utils.manual_seed"
)

if self.state is None or self._is_done(self.state):
# Create new state
if max_epochs is None:
max_epochs = 1
if epoch_length is None:
if hasattr(data, "__len__"):
epoch_length = len(data)
if epoch_length < 1:
raise ValueError("Input data has zero size. Please provide non-empty data")

self.state = State(iteration=0, epoch=0, max_epochs=max_epochs, epoch_length=epoch_length)
self.logger.info("Engine run starting with max_epochs={}.".format(max_epochs))
else:
# Keep actual state and override it if input args provided
if self.state is not None:
# Check and apply overridden parameters
if max_epochs is not None:
if max_epochs < self.state.epoch:
raise ValueError(
Expand All @@ -610,6 +599,20 @@ def switch_batch(engine):
epoch_length, self.state.epoch_length
)
)

if self.state is None or self._is_done(self.state):
# Create new state
if max_epochs is None:
max_epochs = 1
if epoch_length is None:
if hasattr(data, "__len__"):
epoch_length = len(data)
if epoch_length < 1:
raise ValueError("Input data has zero size. Please provide non-empty data")

self.state = State(iteration=0, epoch=0, max_epochs=max_epochs, epoch_length=epoch_length)
self.logger.info("Engine run starting with max_epochs={}.".format(max_epochs))
else:
self.logger.info(
"Engine run resuming from iteration {}, epoch {} until {} epochs".format(
self.state.iteration, self.state.epoch, self.state.max_epochs
Expand Down
20 changes: 20 additions & 0 deletions tests/ignite/engine/test_engine_state_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,26 @@ def test_load_state_dict_with_params_overriding_integration():
engine.run(data, epoch_length=90)


def test_continue_training():
    # Regression test for https://github.com/pytorch/ignite/issues/993 :
    # calling run() again on a finished engine with a larger max_epochs
    # must resume training rather than restart it from scratch.
    num_epochs = 2
    batches = range(10)
    trainer = Engine(lambda e, b: 1)

    first_state = trainer.run(batches, max_epochs=num_epochs)
    assert first_state.max_epochs == num_epochs
    assert first_state.iteration == num_epochs * len(batches)
    assert first_state.epoch == num_epochs

    @trainer.on(Events.STARTED)
    def assert_continue_training():
        # Registered after the first run, so this fires only on the second
        # run; on resume, STARTED must see the epoch counter where the
        # previous run left it.
        assert trainer.state.epoch == num_epochs

    second_state = trainer.run(batches, max_epochs=num_epochs * 2)
    assert second_state.max_epochs == num_epochs * 2
    assert second_state.iteration == num_epochs * 2 * len(batches)
    assert second_state.epoch == num_epochs * 2


def test_state_dict_with_user_keys_integration(dirname):
engine = Engine(lambda e, b: 1)
engine.state_dict_user_keys.append("alpha")
Expand Down