diff --git a/ignite/engine/engine.py b/ignite/engine/engine.py
index e4911aaf0aaf..b1d386500540 100644
--- a/ignite/engine/engine.py
+++ b/ignite/engine/engine.py
@@ -564,8 +564,8 @@ def run(
             epoch_length (int, optional): Number of iterations to count as one epoch. By default, it can be set as
                 `len(data)`. If `data` is an iterator and `epoch_length` is not set, an error is raised.
                 This argument should be `None` if run is resuming from a state.
-            seed (int, optional): Seed to setup at each epoch for reproducible runs.
-                This argument should be `None` if run is resuming from a state.
+            seed (int, optional): Seed to use for dataflow consistency, by default it
+                will respect the global random state. This argument should be `None` if run is resuming from a state.

         Returns:
             State: output state.
@@ -594,7 +594,7 @@ def switch_batch(engine):
         if max_epochs is None:
             max_epochs = 1
         if seed is None:
-            seed = 12
+            seed = torch.randint(0, int(1e9), (1,)).item()
         if epoch_length is None:
             if hasattr(data, "__len__"):
                 epoch_length = len(data)
diff --git a/tests/ignite/engine/test_engine.py b/tests/ignite/engine/test_engine.py
index 3dfd15a52432..bd66f6dc5a60 100644
--- a/tests/ignite/engine/test_engine.py
+++ b/tests/ignite/engine/test_engine.py
@@ -6,6 +6,7 @@
 import torch

 from ignite.engine import Engine, Events, State
+from ignite.metrics import Average


 def test_terminate():
@@ -529,3 +530,35 @@ def foo(e, b):
     engine.run(data_loader, epoch_length=10, max_epochs=5)

     assert counter[0] == 50
+
+
+def test_engine_random_state():
+    def random_data_generator():
+        while True:
+            yield torch.randint(0, 100, size=(5,))
+
+    def sum_data(engine, batch):
+        result = torch.sum(batch)
+        return result
+
+    def get_engine():
+        engine = Engine(sum_data)
+        average = Average()
+        average.attach(engine, "average")
+        return engine
+
+    torch.manual_seed(34)
+    engine = get_engine()
+    state1 = engine.run(random_data_generator(), max_epochs=2, epoch_length=2)
+
+    torch.manual_seed(34)
+    engine = get_engine()
+    state2 = engine.run(random_data_generator(), max_epochs=2, epoch_length=2)
+
+    torch.manual_seed(42)
+    engine = get_engine()
+    state3 = engine.run(random_data_generator(), max_epochs=2, epoch_length=2)
+
+    assert state1.metrics["average"] == pytest.approx(state2.metrics["average"])
+    assert state1.metrics["average"] != pytest.approx(state3.metrics["average"])
+    assert state2.metrics["average"] != pytest.approx(state3.metrics["average"])
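
For context, a minimal stand-alone sketch (not part of the diff) of the mechanism the new default relies on: when `seed` is omitted, `run` now samples it from torch's global RNG via `torch.randint` instead of using the hard-coded `12`, so the sampled seed, and with it the dataflow, is reproducible whenever the caller has set `torch.manual_seed` beforehand and varies otherwise. The variable names below are illustrative only.

# Illustration of the default-seed behaviour introduced by the patch.
# Only the documented torch API (torch.manual_seed, torch.randint) is used.
import torch

torch.manual_seed(34)
a = torch.randint(0, int(1e9), (1,)).item()

torch.manual_seed(34)
b = torch.randint(0, int(1e9), (1,)).item()

torch.manual_seed(42)
c = torch.randint(0, int(1e9), (1,)).item()

assert a == b  # same global seed -> same default engine seed -> same dataflow
assert a != c  # different global seed -> (almost surely) a different default seed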