diff --git a/ignite/engine/__init__.py b/ignite/engine/__init__.py
index e8c4e8b6d835..c4870fa6adc2 100644
--- a/ignite/engine/__init__.py
+++ b/ignite/engine/__init__.py
@@ -185,7 +185,7 @@ def supervised_training_step_amp(
     """
 
     try:
-        from torch.amp import autocast, GradScaler
+        from torch.amp import autocast
     except ImportError:
         raise ImportError("Please install torch>=1.12.0 to use amp_mode='amp'.")
@@ -412,7 +412,7 @@ def _check_arg(
         try:
             from torch.amp import GradScaler
         except ImportError:
-            raise ImportError("Please install torch>=1.6.0 to use scaler argument.")
+            raise ImportError("Please install torch>=2.3.1 to use scaler argument.")
         scaler = GradScaler(enabled=True)
 
     if on_tpu:
diff --git a/tests/ignite/engine/test_create_supervised.py b/tests/ignite/engine/test_create_supervised.py
index 6bd759f9c2b5..5a6255c5b51e 100644
--- a/tests/ignite/engine/test_create_supervised.py
+++ b/tests/ignite/engine/test_create_supervised.py
@@ -167,7 +167,7 @@ def _():
     trainer.run(data)
 
 
-@pytest.mark.skipif(Version(torch.__version__) < Version("1.12.0"), reason="Skip if < 1.12.0")
+@pytest.mark.skipif(Version(torch.__version__) < Version("2.3.1"), reason="Skip if < 2.3.1")
 def test_create_supervised_training_scalar_assignment():
     with mock.patch("ignite.engine._check_arg") as check_arg_mock:
         check_arg_mock.return_value = None, torch.amp.GradScaler(enabled=False)
@@ -456,11 +456,11 @@ def test_create_supervised_trainer_amp_error(mock_torch_cuda_amp_module):
     _test_create_supervised_trainer_wrong_accumulation(trainer_device="cpu", amp_mode="amp")
     with pytest.raises(ImportError, match="Please install torch>=1.12.0 to use amp_mode='amp'."):
         _test_create_supervised_trainer(amp_mode="amp")
-    with pytest.raises(ImportError, match="Please install torch>=1.6.0 to use scaler argument."):
+    with pytest.raises(ImportError, match="Please install torch>=2.3.1 to use scaler argument."):
         _test_create_supervised_trainer(amp_mode="amp", scaler=True)
 
 
-@pytest.mark.skipif(Version(torch.__version__) < Version("1.12.0"), reason="Skip if < 1.12.0")
+@pytest.mark.skipif(Version(torch.__version__) < Version("2.3.1"), reason="Skip if < 2.3.1")
 def test_create_supervised_trainer_scaler_not_amp():
     scaler = torch.amp.GradScaler(enabled=torch.cuda.is_available())
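
Note: the common thread in both hunks is a guarded import of the device-agnostic `torch.amp.GradScaler`, whose availability the updated error messages now tie to torch>=2.3.1. Below is a minimal sketch of that pattern in isolation, assuming torch>=2.3.1; `make_grad_scaler` is a hypothetical helper name for illustration, not part of ignite's or torch's API.

```python
# Minimal sketch of the version-gated scaler construction the diff enforces.
# `make_grad_scaler` is a hypothetical helper, not an ignite or torch API.
import torch


def make_grad_scaler(enabled: bool = True) -> "torch.amp.GradScaler":
    try:
        # The device-agnostic GradScaler lives in torch.amp; on older torch
        # builds this import fails, so we surface the same hint the diff
        # standardizes on instead of a bare ImportError.
        from torch.amp import GradScaler
    except ImportError:
        raise ImportError("Please install torch>=2.3.1 to use scaler argument.")
    return GradScaler(enabled=enabled)


# Usage mirroring the updated test: disable scaling when CUDA is absent.
scaler = make_grad_scaler(enabled=torch.cuda.is_available())
```

The same gating explains the test-side changes: the `skipif` guards and `pytest.raises` match strings move from the old 1.6.0/1.12.0 thresholds to 2.3.1, since that is the first release where `torch.amp.GradScaler` can be imported directly.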