opcheck should be usable without optional dependencies #127292

Closed
wants to merge 5 commits
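For context, torch.library.opcheck is the public API this change targets: it validates an operator's registrations (schema, autograd, fake/meta, etc.) against a sample invocation. A minimal sketch of the intended usage, mirroring the new test added below (torch.ops.aten.sin.default is just a convenient built-in op):

import sys
import torch

x = torch.randn(3, requires_grad=True)
# Validate the op's registered implementations for this sample input.
torch.library.opcheck(torch.ops.aten.sin.default, (x,))

# The point of this PR: running opcheck alone should not import optional
# test-time dependencies such as expecttest.
assert "expecttest" not in sys.modules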
15 changes: 15 additions & 0 deletions test/test_custom_ops.py
@@ -3153,6 +3153,21 @@ def test_opcheck_bad_op(self):
},
)

    def test_opcheck_does_not_require_extra_deps(self):
        # torch.testing._internal.common_utils comes with a lot of additional
        # test-time dependencies. Since opcheck is public API, it should be
        # usable with only PyTorch's install-time dependencies.
        cmd = [
            sys.executable,
            "-c",
            "import torch; import sys; \
            x = torch.randn(3, requires_grad=True); \
            torch.library.opcheck(torch.ops.aten.sin.default, (x,)); \
            assert 'expecttest' not in sys.modules; \
            assert 'torch.testing._internal.common_utils' not in sys.modules",
        ]
        subprocess.check_output(cmd, shell=False)


only_for = ("cpu", "cuda")
instantiate_device_type_tests(TestCustomOpTesting, globals(), only_for=only_for)
1 change: 1 addition & 0 deletions torch/testing/__init__.py
@@ -1,3 +1,4 @@
from torch._C import FileCheck as FileCheck
from . import _utils
from ._comparison import assert_allclose, assert_close as assert_close
from ._creation import make_tensor as make_tensor
19 changes: 2 additions & 17 deletions torch/testing/_internal/common_methods_invocations.py
@@ -36,9 +36,10 @@
make_fullrank_matrices_with_distinct_singular_values,
TEST_WITH_ROCM, IS_WINDOWS, IS_MACOS, TEST_SCIPY,
torch_to_numpy_dtype_dict, numpy_to_torch_dtype, TEST_WITH_ASAN,
GRADCHECK_NONDET_TOL, freeze_rng_state, slowTest, TEST_WITH_SLOW,
GRADCHECK_NONDET_TOL, slowTest, TEST_WITH_SLOW,
TEST_WITH_TORCHINDUCTOR
)
from torch.testing._utils import wrapper_set_seed

import torch._refs as refs # noqa: F401
import torch._refs.nn.functional
@@ -11299,22 +11300,6 @@ def reference_mse_loss(input, target, reduction="mean"):
    return se


def wrapper_set_seed(op, *args, **kwargs):
    """Wrapper to set seed manually for some functions like dropout
    See: https://github.com/pytorch/pytorch/pull/62315#issuecomment-896143189 for more details.
    """
    with freeze_rng_state():
        torch.manual_seed(42)
        output = op(*args, **kwargs)

        if isinstance(output, torch.Tensor) and output.device.type == "lazy":
            # We need to call mark step inside freeze_rng_state so that numerics
            # match eager execution
            torch._lazy.mark_step()

        return output


def reference_layer_norm(inp: np.ndarray, normalized_shape: Tuple[int], weight=None, bias=None, eps=1e-5):
    return reference_native_layer_norm(inp, normalized_shape, weight, bias, eps)[0]

34 changes: 4 additions & 30 deletions torch/testing/_internal/common_utils.py
@@ -96,14 +96,17 @@
from torch.utils._import_utils import _check_module_exists
import torch.utils._pytree as pytree

from .composite_compliance import no_dispatch
try:
    import pytest
    has_pytest = True
except ImportError:
    has_pytest = False


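# Backwards-compatible shim: the actual implementation moved to torch.testing._utils
# so it can be imported without this module's optional test-time dependencies.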
def freeze_rng_state(*args, **kwargs):
    return torch.testing._utils.freeze_rng_state(*args, **kwargs)


# Class to keep track of test flags configurable by environment variables.
# Flags set here are intended to be read-only and should not be modified after
# definition.
@@ -1949,35 +1952,6 @@ def set_rng_seed(seed):
    np.random.seed(seed)


disable_functorch = torch._C._DisableFuncTorch


@contextlib.contextmanager
def freeze_rng_state():
    # no_dispatch needed for test_composite_compliance
    # Some OpInfos use freeze_rng_state for rng determinism, but
    # test_composite_compliance overrides dispatch for all torch functions
    # which we need to disable to get and set rng state
    with no_dispatch(), disable_functorch():
        rng_state = torch.get_rng_state()
        if torch.cuda.is_available():
            cuda_rng_state = torch.cuda.get_rng_state()
    try:
        yield
    finally:
        # Modes are not happy with torch.cuda.set_rng_state
        # because it clones the state (which could produce a Tensor Subclass)
        # and then grabs the new tensor's data pointer in generator.set_state.
        #
        # In the long run torch.cuda.set_rng_state should probably be
        # an operator.
        #
        # NB: Mode disable is to avoid running cross-ref tests on thes seeding
        with no_dispatch(), disable_functorch():
            if torch.cuda.is_available():
                torch.cuda.set_rng_state(cuda_rng_state)
            torch.set_rng_state(rng_state)

@contextlib.contextmanager
def set_default_dtype(dtype):
    saved_dtype = torch.get_default_dtype()
2 changes: 1 addition & 1 deletion torch/testing/_internal/optests/aot_autograd.py
@@ -2,7 +2,7 @@

import torch
import torch.utils._pytree as pytree
from torch.testing._internal.common_methods_invocations import wrapper_set_seed
from torch.testing._utils import wrapper_set_seed
from functorch.compile import compiled_function, min_cut_rematerialization_partition, nop
from .make_fx import randomize
import re
2 changes: 1 addition & 1 deletion torch/testing/_internal/optests/make_fx.py
@@ -2,7 +2,7 @@

import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.common_methods_invocations import wrapper_set_seed
from torch.testing._utils import wrapper_set_seed
import torch.utils._pytree as pytree


50 changes: 50 additions & 0 deletions torch/testing/_utils.py
@@ -0,0 +1,50 @@
import contextlib

import torch

# Common testing utilities for use in public testing APIs.
# NB: these should all be importable without optional dependencies
# (like numpy and expecttest).


def wrapper_set_seed(op, *args, **kwargs):
    """Wrapper to set seed manually for some functions like dropout
    See: https://github.com/pytorch/pytorch/pull/62315#issuecomment-896143189 for more details.
    """
    with freeze_rng_state():
        torch.manual_seed(42)
        output = op(*args, **kwargs)

        if isinstance(output, torch.Tensor) and output.device.type == "lazy":
            # We need to call mark step inside freeze_rng_state so that numerics
            # match eager execution
            torch._lazy.mark_step()  # type: ignore[attr-defined]

        return output


@contextlib.contextmanager
def freeze_rng_state():
    # no_dispatch needed for test_composite_compliance
    # Some OpInfos use freeze_rng_state for rng determinism, but
    # test_composite_compliance overrides dispatch for all torch functions
    # which we need to disable to get and set rng state
    with torch.utils._mode_utils.no_dispatch(), torch._C._DisableFuncTorch():
        rng_state = torch.get_rng_state()
        if torch.cuda.is_available():
            cuda_rng_state = torch.cuda.get_rng_state()
    try:
        yield
    finally:
        # Modes are not happy with torch.cuda.set_rng_state
        # because it clones the state (which could produce a Tensor Subclass)
        # and then grabs the new tensor's data pointer in generator.set_state.
        #
        # In the long run torch.cuda.set_rng_state should probably be
        # an operator.
        #
        # NB: Mode disable is to avoid running cross-ref tests on this seeding
        with torch.utils._mode_utils.no_dispatch(), torch._C._DisableFuncTorch():
            if torch.cuda.is_available():
                torch.cuda.set_rng_state(cuda_rng_state)  # type: ignore[possibly-undefined]
            torch.set_rng_state(rng_state)
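
As a usage illustration (not part of this PR), the helpers above behave roughly like this; dropout stands in for any op with randomness:

import torch
from torch.testing._utils import freeze_rng_state, wrapper_set_seed

x = torch.randn(4, 4)

# wrapper_set_seed pins the seed to 42 inside freeze_rng_state, so repeated
# calls of a random op are deterministic and the caller's RNG state is untouched.
a = wrapper_set_seed(torch.nn.functional.dropout, x, p=0.5)
b = wrapper_set_seed(torch.nn.functional.dropout, x, p=0.5)
assert torch.equal(a, b)

# freeze_rng_state on its own restores the global RNG state on exit.
before = torch.get_rng_state()
with freeze_rng_state():
    torch.manual_seed(0)
    torch.rand(2)
assert torch.equal(before, torch.get_rng_state())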