opcheck should be usable without optional dependencies (#127292)
This PR excises opcheck's dependency on
torch.testing._internal.common_utils (which in turn pulls in
expecttest and hypothesis). We do this by moving the utilities we need into
torch.testing._utils and adding a test for it.

Fixes #126870, #126871

Test Plan:
- new tests
Pull Request resolved: #127292
Approved by: https://github.com/williamwen42
ghstack dependencies: #127291
zou3519 authored and pytorchmergebot committed May 29, 2024
1 parent 8a31c2a commit 28de914
Showing 7 changed files with 74 additions and 49 deletions.
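
For context, the public entry point this change keeps lightweight is torch.library.opcheck. Below is a minimal sketch of the intended behavior, mirroring the new test further down and using a built-in op purely for illustration:

import sys

import torch

# opcheck verifies that an operator is registered correctly (schema, fake
# tensor support, autograd, etc.) against the given sample inputs.
x = torch.randn(3, requires_grad=True)
torch.library.opcheck(torch.ops.aten.sin.default, (x,))

# After this PR, running the check should not have imported any of the
# optional test-only packages as a side effect.
assert "expecttest" not in sys.modules
assert "torch.testing._internal.common_utils" not in sys.modules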
15 changes: 15 additions & 0 deletions test/test_custom_ops.py
@@ -3154,6 +3154,21 @@ def test_opcheck_bad_op(self):
},
)

def test_opcheck_does_not_require_extra_deps(self):
# torch.testing._internal.common_utils comes with a lot of additional
# test-time dependencies. Since opcheck is public API, it should be
# usable with only PyTorch's install-time dependencies.
cmd = [
sys.executable,
"-c",
"import torch; import sys; \
x = torch.randn(3, requires_grad=True); \
torch.library.opcheck(torch.ops.aten.sin.default, (x,)); \
assert 'expecttest' not in sys.modules; \
assert 'torch.testing._internal.common_utils' not in sys.modules",
]
subprocess.check_output(cmd, shell=False)


only_for = ("cpu", "cuda")
instantiate_device_type_tests(TestCustomOpTesting, globals(), only_for=only_for)
1 change: 1 addition & 0 deletions torch/testing/__init__.py
@@ -1,3 +1,4 @@
from torch._C import FileCheck as FileCheck
from . import _utils
from ._comparison import assert_allclose, assert_close as assert_close
from ._creation import make_tensor as make_tensor
19 changes: 2 additions & 17 deletions torch/testing/_internal/common_methods_invocations.py
@@ -36,9 +36,10 @@
make_fullrank_matrices_with_distinct_singular_values,
TEST_WITH_ROCM, IS_WINDOWS, IS_MACOS, TEST_SCIPY,
torch_to_numpy_dtype_dict, numpy_to_torch_dtype, TEST_WITH_ASAN,
GRADCHECK_NONDET_TOL, freeze_rng_state, slowTest, TEST_WITH_SLOW,
GRADCHECK_NONDET_TOL, slowTest, TEST_WITH_SLOW,
TEST_WITH_TORCHINDUCTOR
)
from torch.testing._utils import wrapper_set_seed

import torch._refs as refs # noqa: F401
import torch._refs.nn.functional
@@ -11299,22 +11300,6 @@ def reference_mse_loss(input, target, reduction="mean"):
return se


def wrapper_set_seed(op, *args, **kwargs):
"""Wrapper to set seed manually for some functions like dropout
See: https://github.com/pytorch/pytorch/pull/62315#issuecomment-896143189 for more details.
"""
with freeze_rng_state():
torch.manual_seed(42)
output = op(*args, **kwargs)

if isinstance(output, torch.Tensor) and output.device.type == "lazy":
# We need to call mark step inside freeze_rng_state so that numerics
# match eager execution
torch._lazy.mark_step()

return output


def reference_layer_norm(inp: np.ndarray, normalized_shape: Tuple[int], weight=None, bias=None, eps=1e-5):
return reference_native_layer_norm(inp, normalized_shape, weight, bias, eps)[0]

34 changes: 4 additions & 30 deletions torch/testing/_internal/common_utils.py
@@ -96,14 +96,17 @@
from torch.utils._import_utils import _check_module_exists
import torch.utils._pytree as pytree

from .composite_compliance import no_dispatch
try:
import pytest
has_pytest = True
except ImportError:
has_pytest = False


def freeze_rng_state(*args, **kwargs):
return torch.testing._utils.freeze_rng_state(*args, **kwargs)


# Class to keep track of test flags configurable by environment variables.
# Flags set here are intended to be read-only and should not be modified after
# definition.
@@ -1949,35 +1952,6 @@ def set_rng_seed(seed):
np.random.seed(seed)


disable_functorch = torch._C._DisableFuncTorch


@contextlib.contextmanager
def freeze_rng_state():
# no_dispatch needed for test_composite_compliance
# Some OpInfos use freeze_rng_state for rng determinism, but
# test_composite_compliance overrides dispatch for all torch functions
# which we need to disable to get and set rng state
with no_dispatch(), disable_functorch():
rng_state = torch.get_rng_state()
if torch.cuda.is_available():
cuda_rng_state = torch.cuda.get_rng_state()
try:
yield
finally:
# Modes are not happy with torch.cuda.set_rng_state
# because it clones the state (which could produce a Tensor Subclass)
# and then grabs the new tensor's data pointer in generator.set_state.
#
# In the long run torch.cuda.set_rng_state should probably be
# an operator.
#
# NB: Mode disable is to avoid running cross-ref tests on the seeding
with no_dispatch(), disable_functorch():
if torch.cuda.is_available():
torch.cuda.set_rng_state(cuda_rng_state)
torch.set_rng_state(rng_state)

@contextlib.contextmanager
def set_default_dtype(dtype):
saved_dtype = torch.get_default_dtype()
2 changes: 1 addition & 1 deletion torch/testing/_internal/optests/aot_autograd.py
@@ -2,7 +2,7 @@

import torch
import torch.utils._pytree as pytree
from torch.testing._internal.common_methods_invocations import wrapper_set_seed
from torch.testing._utils import wrapper_set_seed
from functorch.compile import compiled_function, min_cut_rematerialization_partition, nop
from .make_fx import randomize
import re
2 changes: 1 addition & 1 deletion torch/testing/_internal/optests/make_fx.py
@@ -2,7 +2,7 @@

import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.common_methods_invocations import wrapper_set_seed
from torch.testing._utils import wrapper_set_seed
import torch.utils._pytree as pytree


50 changes: 50 additions & 0 deletions torch/testing/_utils.py
@@ -0,0 +1,50 @@
import contextlib

import torch

# Common testing utilities for use in public testing APIs.
# NB: these should all be importable without optional dependencies
# (like numpy and expecttest).


def wrapper_set_seed(op, *args, **kwargs):
"""Wrapper to set seed manually for some functions like dropout
See: https://github.com/pytorch/pytorch/pull/62315#issuecomment-896143189 for more details.
"""
with freeze_rng_state():
torch.manual_seed(42)
output = op(*args, **kwargs)

if isinstance(output, torch.Tensor) and output.device.type == "lazy":
# We need to call mark step inside freeze_rng_state so that numerics
# match eager execution
torch._lazy.mark_step() # type: ignore[attr-defined]

return output


@contextlib.contextmanager
def freeze_rng_state():
# no_dispatch needed for test_composite_compliance
# Some OpInfos use freeze_rng_state for rng determinism, but
# test_composite_compliance overrides dispatch for all torch functions
# which we need to disable to get and set rng state
with torch.utils._mode_utils.no_dispatch(), torch._C._DisableFuncTorch():
rng_state = torch.get_rng_state()
if torch.cuda.is_available():
cuda_rng_state = torch.cuda.get_rng_state()
try:
yield
finally:
# Modes are not happy with torch.cuda.set_rng_state
# because it clones the state (which could produce a Tensor Subclass)
# and then grabs the new tensor's data pointer in generator.set_state.
#
# In the long run torch.cuda.set_rng_state should probably be
# an operator.
#
# NB: Mode disable is to avoid running cross-ref tests on the seeding
with torch.utils._mode_utils.no_dispatch(), torch._C._DisableFuncTorch():
if torch.cuda.is_available():
torch.cuda.set_rng_state(cuda_rng_state) # type: ignore[possibly-undefined]
torch.set_rng_state(rng_state)
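
Downstream callers now import these helpers from torch.testing._utils rather than from common_methods_invocations or common_utils. A brief sketch of typical use (the tensors and ops here are illustrative only):

import torch
from torch.testing._utils import freeze_rng_state, wrapper_set_seed

# freeze_rng_state saves the CPU (and, if available, CUDA) RNG state and
# restores it on exit, so seeding inside the block does not leak out.
with freeze_rng_state():
    torch.manual_seed(0)
    inside = torch.rand(3)

# The global RNG stream continues as if the block never ran.
outside = torch.rand(3)

# wrapper_set_seed runs an op under a fixed seed (42) inside freeze_rng_state;
# OpInfos use it to make randomized ops such as dropout reproducible.
out = wrapper_set_seed(torch.nn.functional.dropout, torch.ones(4), 0.5)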
