Commit df819ca
Fixed issue when DATALOADER_STOP_ITERATION event is triggered when engine.run(data=None, ...) (#3217)
vfdev-5 committed Mar 24, 2024
1 parent 2d3f42a commit df819ca
Showing 2 changed files with 54 additions and 15 deletions.
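
For context, here is a minimal sketch of the scenario this commit fixes, adapted from the regression test added below in tests/ignite/engine/test_engine.py; the step and handler names are illustrative and not part of the commit. Before the change, the dataloader-related events could fire even though run() was called without data; with the fix, the handler below is never invoked.

from ignite.engine import Engine, Events

def training_step(engine, batch):
    # run(data=None, ...) provides no batches, so `batch` is None here
    return None

trainer = Engine(training_step)

@trainer.on(Events.DATALOADER_STOP_ITERATION)
@trainer.on(Events.GET_BATCH_STARTED)
@trainer.on(Events.GET_BATCH_COMPLETED)
def on_dataloader_event():
    # With data=None, these events are no longer expected to be emitted
    print("unexpected event:", trainer.last_event_name)

trainer.run(max_epochs=4, epoch_length=4)
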
44 changes: 32 additions & 12 deletions ignite/engine/engine.py
@@ -1037,15 +1037,23 @@ def _run_once_on_dataset_as_gen(self) -> Generator[State, None, float]:

             while True:
                 self.state.batch = self.state.output = None
                 try:
                     # Avoid Events.GET_BATCH_STARTED triggered twice when data iter is restarted
                     if self.last_event_name != Events.DATALOADER_STOP_ITERATION:
-                        self._fire_event(Events.GET_BATCH_STARTED)
-                        yield from self._maybe_terminate_or_interrupt()
+                        # We should not trigger GET_BATCH_STARTED, GET_BATCH_COMPLETED, DATALOADER_STOP_ITERATION events
+                        # if no data was provided to engine.run(data=None, ...)
+                        if self.state.dataloader is not None:
+                            self._fire_event(Events.GET_BATCH_STARTED)
+                            yield from self._maybe_terminate_or_interrupt()
 
                     self.state.batch = next(self._dataloader_iter)
-                    self._fire_event(Events.GET_BATCH_COMPLETED)
-                    yield from self._maybe_terminate_or_interrupt()
+
+                    # We should not trigger GET_BATCH_STARTED, GET_BATCH_COMPLETED, DATALOADER_STOP_ITERATION events
+                    # if no data was provided to engine.run(data=None, ...)
+                    if self.state.dataloader is not None:
+                        self._fire_event(Events.GET_BATCH_COMPLETED)
+                        yield from self._maybe_terminate_or_interrupt()
 
                     iter_counter += 1
                     should_exit = False
@@ -1074,8 +1082,11 @@ def _run_once_on_dataset_as_gen(self) -> Generator[State, None, float]:
                             )
                         break
 
-                    self._fire_event(Events.DATALOADER_STOP_ITERATION)
-                    yield from self._maybe_terminate_or_interrupt()
+                    # We should not trigger GET_BATCH_STARTED, GET_BATCH_COMPLETED, DATALOADER_STOP_ITERATION events
+                    # if no data was provided to engine.run(data=None, ...)
+                    if self.state.dataloader is not None:
+                        self._fire_event(Events.DATALOADER_STOP_ITERATION)
+                        yield from self._maybe_terminate_or_interrupt()
 
                     self._setup_dataloader_iter()
                     should_exit = True
@@ -1198,12 +1209,18 @@ def _run_once_on_dataset_legacy(self) -> float:
                 try:
                     # Avoid Events.GET_BATCH_STARTED triggered twice when data iter is restarted
                     if self.last_event_name != Events.DATALOADER_STOP_ITERATION:
-                        self._fire_event(Events.GET_BATCH_STARTED)
-                        self._maybe_terminate_legacy()
+                        # We should not trigger GET_BATCH_STARTED, GET_BATCH_COMPLETED, DATALOADER_STOP_ITERATION events
+                        # if no data was provided to engine.run(data=None, ...)
+                        if self.state.dataloader is not None:
+                            self._fire_event(Events.GET_BATCH_STARTED)
+                            self._maybe_terminate_legacy()
 
                     self.state.batch = next(self._dataloader_iter)
-                    self._fire_event(Events.GET_BATCH_COMPLETED)
-                    self._maybe_terminate_legacy()
+                    # We should not trigger GET_BATCH_STARTED, GET_BATCH_COMPLETED, DATALOADER_STOP_ITERATION events
+                    # if no data was provided to engine.run(data=None, ...)
+                    if self.state.dataloader is not None:
+                        self._fire_event(Events.GET_BATCH_COMPLETED)
+                        self._maybe_terminate_legacy()
 
                     iter_counter += 1
                     should_exit = False
@@ -1232,8 +1249,11 @@ def _run_once_on_dataset_legacy(self) -> float:
                             )
                         break
 
-                    self._fire_event(Events.DATALOADER_STOP_ITERATION)
-                    self._maybe_terminate_legacy()
+                    # We should not trigger GET_BATCH_STARTED, GET_BATCH_COMPLETED, DATALOADER_STOP_ITERATION events
+                    # if no data was provided to engine.run(data=None, ...)
+                    if self.state.dataloader is not None:
+                        self._fire_event(Events.DATALOADER_STOP_ITERATION)
+                        self._maybe_terminate_legacy()
 
                     self._setup_dataloader_iter()
                     should_exit = True
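
An illustrative check (not part of the commit) of the condition the new guards rely on: engine.state.dataloader mirrors the data argument passed to run(), so it remains None when the engine is driven without data, which is what the added "if self.state.dataloader is not None" tests key on.

from ignite.engine import Engine, Events

engine = Engine(lambda e, batch: batch)  # batch is None when running without data

@engine.on(Events.ITERATION_COMPLETED)
def check_no_dataloader():
    # Running with data=None leaves state.dataloader unset, as the new guards expect
    assert engine.state.dataloader is None

engine.run(max_epochs=1, epoch_length=3)
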
25 changes: 22 additions & 3 deletions tests/ignite/engine/test_engine.py
@@ -601,8 +601,8 @@ def _test_check_triggered_events(self, data, max_epochs, epoch_length, exp_iter_
             Events.EPOCH_COMPLETED: max_epochs,
             Events.ITERATION_STARTED: max_epochs * epoch_length,
             Events.ITERATION_COMPLETED: max_epochs * epoch_length,
-            Events.GET_BATCH_STARTED: max_epochs * epoch_length,
-            Events.GET_BATCH_COMPLETED: max_epochs * epoch_length,
+            Events.GET_BATCH_STARTED: max_epochs * epoch_length if data is not None else 0,
+            Events.GET_BATCH_COMPLETED: max_epochs * epoch_length if data is not None else 0,
             Events.DATALOADER_STOP_ITERATION: (max_epochs - 1) if exp_iter_stops is None else exp_iter_stops,
         }

@@ -617,7 +617,7 @@ def _test_run_check_triggered_events(self):
         self._test_check_triggered_events(
             list(range(100)), max_epochs=5, epoch_length=150, exp_iter_stops=150 * 5 // 100
         )
-        self._test_check_triggered_events(None, max_epochs=5, epoch_length=150)
+        self._test_check_triggered_events(None, max_epochs=5, epoch_length=150, exp_iter_stops=0)
 
     def test_run_check_triggered_events_list(self):
         self._test_run_check_triggered_events()
@@ -1146,6 +1146,25 @@ def train_step(engine, batch):
         assert trainer.state.epoch == 20
         assert trainer.state.dataloader is None
 
+    def test_engine_no_data_events(self):
+        # Reproduces the issue https://github.com/pytorch/ignite/issues/3190
+        max_epochs = 4
+        dataset = range(10)
+
+        def training_step(engine, _):
+            assert engine.state.dataloader is None
+
+        trainer = Engine(training_step)
+        trainer.state.dataiter = iter(dataset)
+
+        @trainer.on(Events.DATALOADER_STOP_ITERATION)
+        @trainer.on(Events.GET_BATCH_STARTED)
+        @trainer.on(Events.GET_BATCH_COMPLETED)
+        def should_not_be_called():
+            assert False, trainer.last_event_name
+
+        trainer.run(max_epochs=max_epochs, epoch_length=4)
+
     @pytest.mark.parametrize("data, epoch_length", [(None, 10), (range(10), None)])
     def test_engine_run_resume(self, data, epoch_length):
         # https://github.com/pytorch/ignite/wiki/Roadmap#runresume-logic-improvements

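As a usage note, running without data is meant for cases where batches are fetched manually inside the step function. The sketch below is an assumed pattern, not code from the repository; with this commit, no GET_BATCH_STARTED, GET_BATCH_COMPLETED, or DATALOADER_STOP_ITERATION handlers fire in this mode because state.dataloader stays None.

from ignite.engine import Engine

data = list(range(10))
data_iter = iter(data)

def training_step(engine, _):
    global data_iter
    # Fetch the next item ourselves; restart the iterator when it is exhausted
    try:
        batch = next(data_iter)
    except StopIteration:
        data_iter = iter(data)
        batch = next(data_iter)
    return batch

trainer = Engine(training_step)
trainer.run(max_epochs=4, epoch_length=4)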