opcheck should be usable without optional dependencies (#127292)
This PR excises opcheck's dependency on torch.testing._internal.common_utils (which comes with dependencies on expecttest and hypothesis). We do this by moving what we need to torch.testing._utils and adding a test for it.

Fixes #126870, #126871

Test Plan:
- new tests

Pull Request resolved: #127292
Approved by: https://github.com/williamwen42
ghstack dependencies: #127291
1 parent 8a31c2a, commit 28de914
Showing 7 changed files with 74 additions and 49 deletions.
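With the helpers relocated, opcheck should import and run in an environment without expecttest or hypothesis installed. A minimal sketch of an opcheck call, assuming a recent PyTorch that provides torch.library.custom_op and torch.library.opcheck (the custom op below is purely illustrative and not part of this commit):

import torch

# Hypothetical custom op used only to give opcheck something to test.
@torch.library.custom_op("mylib::my_sin", mutates_args=())
def my_sin(x: torch.Tensor) -> torch.Tensor:
    return torch.sin(x)

@my_sin.register_fake
def _(x):
    # Fake (meta) implementation so opcheck can exercise FakeTensor support.
    return torch.empty_like(x)

# Runs opcheck's test suite (schema, autograd registration, fake tensor, ...)
# against the op; after this PR it needs no optional test dependencies.
torch.library.opcheck(my_sin, (torch.randn(3),))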
@@ -1,3 +1,4 @@
 from torch._C import FileCheck as FileCheck
+from . import _utils
 from ._comparison import assert_allclose, assert_close as assert_close
 from ._creation import make_tensor as make_tensor
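This import wires the new module into the torch.testing package. A minimal check, assuming only a plain PyTorch install with neither expecttest nor hypothesis available:

# These imports should now succeed without the optional test dependencies.
from torch.testing._utils import freeze_rng_state, wrapper_set_seed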
@@ -0,0 +1,50 @@
import contextlib

import torch

# Common testing utilities for use in public testing APIs.
# NB: these should all be importable without optional dependencies
# (like numpy and expecttest).


def wrapper_set_seed(op, *args, **kwargs):
    """Wrapper to set seed manually for some functions like dropout
    See: https://github.com/pytorch/pytorch/pull/62315#issuecomment-896143189 for more details.
    """
    with freeze_rng_state():
        torch.manual_seed(42)
        output = op(*args, **kwargs)

        if isinstance(output, torch.Tensor) and output.device.type == "lazy":
            # We need to call mark step inside freeze_rng_state so that numerics
            # match eager execution
            torch._lazy.mark_step()  # type: ignore[attr-defined]

        return output


@contextlib.contextmanager
def freeze_rng_state():
    # no_dispatch needed for test_composite_compliance
    # Some OpInfos use freeze_rng_state for rng determinism, but
    # test_composite_compliance overrides dispatch for all torch functions
    # which we need to disable to get and set rng state
    with torch.utils._mode_utils.no_dispatch(), torch._C._DisableFuncTorch():
        rng_state = torch.get_rng_state()
        if torch.cuda.is_available():
            cuda_rng_state = torch.cuda.get_rng_state()
    try:
        yield
    finally:
        # Modes are not happy with torch.cuda.set_rng_state
        # because it clones the state (which could produce a Tensor Subclass)
        # and then grabs the new tensor's data pointer in generator.set_state.
        #
        # In the long run torch.cuda.set_rng_state should probably be
        # an operator.
        #
        # NB: Mode disable is to avoid running cross-ref tests on the seeding
        with torch.utils._mode_utils.no_dispatch(), torch._C._DisableFuncTorch():
            if torch.cuda.is_available():
                torch.cuda.set_rng_state(cuda_rng_state)  # type: ignore[possibly-undefined]
            torch.set_rng_state(rng_state)
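A short usage sketch for the two helpers above (the dropout call and tensor values are illustrative, not part of the commit):

import torch
from torch.testing._utils import freeze_rng_state, wrapper_set_seed

x = torch.ones(8)
# wrapper_set_seed fixes the seed to 42 inside freeze_rng_state, so repeated
# calls to a random op produce the same result without disturbing global RNG.
out1 = wrapper_set_seed(torch.nn.functional.dropout, x, p=0.5)
out2 = wrapper_set_seed(torch.nn.functional.dropout, x, p=0.5)
assert torch.equal(out1, out2)

# freeze_rng_state alone restores the RNG state on exit, so random calls made
# inside the context do not affect random numbers drawn afterwards.
before = torch.get_rng_state()
with freeze_rng_state():
    torch.rand(100)
assert torch.equal(before, torch.get_rng_state())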