Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions ignite/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,8 +564,8 @@ def run(
epoch_length (int, optional): Number of iterations to count as one epoch. By default, it can be set as
`len(data)`. If `data` is an iterator and `epoch_length` is not set, an error is raised.
This argument should be `None` if run is resuming from a state.
seed (int, optional): Seed to setup at each epoch for reproducible runs.
This argument should be `None` if run is resuming from a state.
seed (int, optional): Seed to use for dataflow consistency, by default it
will respect the global random state. This argument should be `None` if run is resuming from a state.

Returns:
State: output state.
Expand Down Expand Up @@ -594,7 +594,7 @@ def switch_batch(engine):
if max_epochs is None:
max_epochs = 1
if seed is None:
seed = 12
seed = torch.randint(0, int(1e9), (1,)).item()
if epoch_length is None:
if hasattr(data, "__len__"):
epoch_length = len(data)
Expand Down
33 changes: 33 additions & 0 deletions tests/ignite/engine/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import torch

from ignite.engine import Engine, Events, State
from ignite.metrics import Average


def test_terminate():
Expand Down Expand Up @@ -529,3 +530,35 @@ def foo(e, b):
engine.run(data_loader, epoch_length=10, max_epochs=5)

assert counter[0] == 50


def test_engine_random_state():
def random_data_generator():
while True:
yield torch.randint(0, 100, size=(5,))

def sum_data(engine, batch):
result = torch.sum(batch)
return result

def get_engine():
engine = Engine(sum_data)
average = Average()
average.attach(engine, "average")
return engine

torch.manual_seed(34)
engine = get_engine()
state1 = engine.run(random_data_generator(), max_epochs=2, epoch_length=2)

torch.manual_seed(34)
engine = get_engine()
state2 = engine.run(random_data_generator(), max_epochs=2, epoch_length=2)

torch.manual_seed(42)
engine = get_engine()
state3 = engine.run(random_data_generator(), max_epochs=2, epoch_length=2)

assert state1.metrics["average"] == pytest.approx(state2.metrics["average"])
assert state1.metrics["average"] != pytest.approx(state3.metrics["average"])
assert state2.metrics["average"] != pytest.approx(state3.metrics["average"])