diff --git a/test/test_ops.py b/test/test_ops.py index 858880e35bc..82f8b6b6eb2 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -845,7 +845,9 @@ def test_frozenbatchnorm2d_repr(self): expected_string = f"FrozenBatchNorm2d({num_features}, eps={eps})" assert repr(t) == expected_string - def test_frozenbatchnorm2d_eps(self): + @pytest.mark.parametrize("seed", range(10)) + def test_frozenbatchnorm2d_eps(self, seed): + torch.random.manual_seed(seed) sample_size = (4, 32, 28, 28) x = torch.rand(sample_size) state_dict = dict(