diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index 7429db6e3bdd1..4988dd4b02774 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -103,6 +103,10 @@ has_pytest = False +def freeze_rng_state(*args, **kwargs): + return torch.testing._utils.freeze_rng_state(*args, **kwargs) + + # Class to keep track of test flags configurable by environment variables. # Flags set here are intended to be read-only and should not be modified after # definition. diff --git a/torch/testing/_utils.py b/torch/testing/_utils.py index ed36aa9468eb2..d4697a93d4f27 100644 --- a/torch/testing/_utils.py +++ b/torch/testing/_utils.py @@ -1,4 +1,5 @@ import contextlib + import torch from torch.utils._mode_utils import no_dispatch @@ -6,6 +7,7 @@ # NB: these should all be importable without optional dependencies # (like numpy and expecttest). + def wrapper_set_seed(op, *args, **kwargs): """Wrapper to set seed manually for some functions like dropout See: https://github.com/pytorch/pytorch/pull/62315#issuecomment-896143189 for more details.