diff --git a/test/distributed/test_data_parallel.py b/test/distributed/test_data_parallel.py
index f3161a1f8cb1..c600651a4339 100644
--- a/test/distributed/test_data_parallel.py
+++ b/test/distributed/test_data_parallel.py
@@ -4,6 +4,7 @@
 from copy import deepcopy
 from collections import OrderedDict
 from itertools import product
+import functools
 
 import torch
 from torch import nn
@@ -11,7 +12,7 @@
 import torch.nn.parallel as dp
 from torch.testing._internal.common_cuda import TEST_MULTIGPU, TEST_CUDA
 from torch.testing._internal.common_utils import run_tests, TestCase, repeat_test_for_types, ALL_TENSORTYPES
-from torch.testing._internal.common_utils import _assertGradAndGradgradChecks
+from torch.testing._internal.common_utils import _assertGradAndGradgradChecks, gradcheck
 from torch.testing._internal.common_utils import dtype2prec_DONTUSE
 from torch.testing._internal.common_utils import skipIfRocm
 import torch.nn.functional as F
@@ -20,6 +21,10 @@
 
 NO_NCCL = not hasattr(torch.distributed, "ProcessGroupNCCL")
 
+# batched grad doesn't support data parallel
+gradcheck = functools.partial(gradcheck, check_batched_grad=False)
+_assertGradAndGradgradChecks = functools.partial(_assertGradAndGradgradChecks, check_batched_grad=False)
+
 class TestDataParallel(TestCase):
 
     @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
@@ -42,7 +47,7 @@ def forward(self, x):
         def fn(t):
             return dpm(inp)
 
-        torch.autograd.gradcheck(fn, (m.t_rg,))
+        gradcheck(fn, (m.t_rg,))
 
     @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
     def test_data_parallel_rnn(self):
diff --git a/test/distributions/test_distributions.py b/test/distributions/test_distributions.py
index 73db638d7c7b..f0bacf7eb282 100644
--- a/test/distributions/test_distributions.py
+++ b/test/distributions/test_distributions.py
@@ -36,9 +36,10 @@
 torch.set_default_dtype(torch.double)
 
 from torch._six import inf
-from torch.testing._internal.common_utils import TestCase, run_tests, set_rng_seed, TEST_WITH_UBSAN, load_tests
+from torch.testing._internal.common_utils import TestCase, run_tests, set_rng_seed, TEST_WITH_UBSAN, load_tests, \
+    gradcheck
 from torch.testing._internal.common_cuda import TEST_CUDA
-from torch.autograd import grad, gradcheck
+from torch.autograd import grad
 from torch.autograd.functional import jacobian
 from torch.distributions import (Bernoulli, Beta, Binomial, Categorical, Cauchy, Chi2,
                                  ContinuousBernoulli, Dirichlet,
diff --git a/test/test_autograd.py b/test/test_autograd.py
index 506298d71a92..4878ba38b3df 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -22,7 +22,6 @@
 
 from torch import nn
 from torch._six import inf, nan, istuple
-from torch.autograd.gradcheck import gradgradcheck, gradcheck
 from torch.autograd.function import once_differentiable
 from torch.autograd.profiler import (profile, format_time, EventList,
                                      FunctionEvent, FunctionEventAvg,
@@ -34,7 +33,8 @@
                                                   suppress_warnings, slowTest,
                                                   load_tests, random_symmetric_matrix,
                                                   IS_WINDOWS, IS_MACOS, CudaMemoryLeakCheck,
-                                                  TemporaryFileName, TEST_WITH_ROCM)
+                                                  TemporaryFileName, TEST_WITH_ROCM,
+                                                  gradcheck, gradgradcheck)
 from torch.autograd import Variable, Function, detect_anomaly, kineto_available
 from torch.autograd.function import InplaceFunction
 import torch.autograd.forward_ad as fwAD
@@ -74,10 +74,6 @@ def getattr_qualified(obj, qname, default=None):
 
 PRECISION = 1e-4
 
-# See #49409, we should remove these if we end up with a global gradcheck setting
-gradcheck = partial(gradcheck, check_batched_grad=True)
-gradgradcheck = partial(gradgradcheck, check_batched_grad=True)
-
 
 @contextlib.contextmanager
 def backward_engine(engine):
@@ -2217,10 +2213,10 @@ def backward(ctx, grad_output):
 
         x = torch.tensor(2).double().requires_grad_()
 
-        self.assertTrue(torch.autograd.gradcheck(double, x))
-        self.assertTrue(torch.autograd.gradgradcheck(double, x))
-        self.assertTrue(torch.autograd.gradcheck(double2, x))
-        self.assertTrue(torch.autograd.gradgradcheck(double2, x))
+        self.assertTrue(gradcheck(double, x))
+        self.assertTrue(gradgradcheck(double, x))
+        self.assertTrue(gradcheck(double2, x))
+        self.assertTrue(gradgradcheck(double2, x))
 
         y = double(x)
         torch.autograd.grad(y, x, create_graph=True)
@@ -7065,15 +7061,13 @@ def test_lstmcell_backward_only_one_output_grad(self, device):
         self.assertFalse(s.grad is None or s.grad.abs().sum().item() == 0)
 
     def _test_rnn_mod(self, mod, inp):
-        from functools import partial
-
         def flatten_out(mod, inp):
             out = mod(inp)
             return tuple([t if isinstance(t, torch.Tensor) else tt for t in out for tt in t])
         gradcheckfunc = partial(flatten_out, mod)
         with torch.backends.cudnn.flags(enabled=False):
-            torch.autograd.gradcheck(gradcheckfunc, inp)
-            torch.autograd.gradgradcheck(gradcheckfunc, inp)
+            gradcheck(gradcheckfunc, inp, check_batched_grad=False)
+            gradgradcheck(gradcheckfunc, inp, check_batched_grad=False)
 
         if inp.is_cuda and not TEST_WITH_ROCM:
             # Assert that we have good error message around unsupported CuDNN double backward
diff --git a/test/test_cpp_extensions_jit.py b/test/test_cpp_extensions_jit.py
index 0ee99e7b5f68..efda7cb2cf6e 100644
--- a/test/test_cpp_extensions_jit.py
+++ b/test/test_cpp_extensions_jit.py
@@ -13,7 +13,7 @@
 import torch.backends.cudnn
 import torch.utils.cpp_extension
 from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME
-from torch.autograd.gradcheck import gradcheck
+from torch.testing._internal.common_utils import gradcheck
 
 
 TEST_CUDA = torch.cuda.is_available() and CUDA_HOME is not None
diff --git a/test/test_linalg.py b/test/test_linalg.py
index 851866fcffed..f76731c3a586 100644
--- a/test/test_linalg.py
+++ b/test/test_linalg.py
@@ -12,12 +12,12 @@
 import random
 from random import randrange
 from itertools import product
-from functools import reduce, partial
+from functools import reduce
 
 from torch.testing._internal.common_utils import \
     (TestCase, run_tests, TEST_SCIPY, IS_MACOS, IS_WINDOWS, slowTest,
      TEST_WITH_ASAN, make_tensor, TEST_WITH_ROCM, IS_FBCODE, IS_REMOTE_GPU,
-     wrapDeterministicFlagAPITest, iter_indices)
+     wrapDeterministicFlagAPITest, iter_indices, gradcheck, gradgradcheck)
 from torch.testing._internal.common_device_type import \
     (instantiate_device_type_tests, dtypes,
      onlyCPU, skipCUDAIf, skipCUDAIfNoMagma, skipCPUIfNoLapack, precisionOverride,
@@ -25,7 +25,6 @@
      onlyCUDA)
 from torch.testing import floating_and_complex_types, floating_types, all_types
 from torch.testing._internal.common_cuda import SM53OrLater, tf32_on_and_off, CUDA11OrLater, CUDA9
-from torch.autograd import gradcheck, gradgradcheck
 
 # Protects against includes accidentally setting the default dtype
 # NOTE: jit_metaprogramming_utils sets the default dtype to double!
@@ -35,10 +34,6 @@
 if TEST_SCIPY:
     import scipy
 
-# See #49409, we should remove these if we end up with a global gradcheck setting
-gradcheck = partial(gradcheck, check_batched_grad=True)
-gradgradcheck = partial(gradgradcheck, check_batched_grad=True)
-
 class TestLinalg(TestCase):
     exact_dtype = True
@@ -3218,7 +3213,7 @@ def check(equation, *operands):
 
             # Check autograd
             ops = [op.detach().requires_grad_() for op in operands]
-            self.assertTrue(torch.autograd.gradcheck(lambda *ops: torch.einsum(equation, ops), ops))
+            self.assertTrue(gradcheck(lambda *ops: torch.einsum(equation, ops), ops))
             for op in ops:
                 self.assertTrue(op._version == 0)
 
diff --git a/test/test_mkldnn.py b/test/test_mkldnn.py
index 66187fe2463a..ad2399b69a0b 100644
--- a/test/test_mkldnn.py
+++ b/test/test_mkldnn.py
@@ -1,5 +1,6 @@
 import copy
 import itertools
+import functools
 import unittest
 
 try:
@@ -15,9 +16,11 @@
 import torch.jit
 import torch.backends.mkldnn
 from torch.utils import mkldnn as mkldnn_utils
-from torch.testing._internal.common_utils import TestCase, run_tests, TemporaryFileName
+from torch.testing._internal.common_utils import TestCase, run_tests, TemporaryFileName, gradcheck, gradgradcheck
 
-from torch.autograd.gradcheck import gradgradcheck, gradcheck
+# batched grad doesn't support mkldnn
+gradcheck = functools.partial(gradcheck, check_batched_grad=False)
+gradgradcheck = functools.partial(gradgradcheck, check_batched_grad=False)
 
 types = [torch.float, torch.bfloat16]
 
diff --git a/test/test_nn.py b/test/test_nn.py
index 648341c6acd0..e2a2299314c6 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -10,7 +10,7 @@
 import pickle
 from copy import deepcopy
 from itertools import repeat, product
-from functools import reduce, partial
+from functools import reduce
 from operator import mul
 from collections import OrderedDict
 
@@ -29,8 +29,6 @@
 from torch.nn.utils import clip_grad_norm_, clip_grad_value_
 import torch.nn.utils.prune as prune
 from torch.nn.utils import parameters_to_vector, vector_to_parameters
-from torch.autograd import gradcheck
-from torch.autograd.gradcheck import gradgradcheck
 from torch.nn import Parameter
 from torch.nn.parameter import UninitializedParameter
 from torch.nn.parallel._functions import Broadcast
@@ -51,7 +49,7 @@
 from hypothesis import given
 import torch.testing._internal.hypothesis_utils as hu
 
-from torch.testing._internal.common_utils import _assertGradAndGradgradChecks
+from torch.testing._internal.common_utils import _assertGradAndGradgradChecks, gradcheck, gradgradcheck
 from torch.testing._internal.common_utils import dtype2prec_DONTUSE
 from torch.testing._internal.common_cuda import tf32_on_and_off, tf32_is_not_fp32, tf32_off, tf32_on
 from torch.types import _TensorOrTensors
@@ -72,11 +70,6 @@
 
 DOUBLE_TENSORTYPES = [torch.double]
 
-# See #49409, we should remove these if we end up with a global gradcheck setting
-gradcheck = partial(gradcheck, check_batched_grad=True)
-gradgradcheck = partial(gradgradcheck, check_batched_grad=True)
-_assertGradAndGradgradChecks = partial(_assertGradAndGradgradChecks, check_batched_grad=True)
-
 
 # WARNING: If you add a new top-level test case to this file, you MUST
 # update test/run_test.py to list it, otherwise it will NOT be run in
@@ -3093,7 +3086,7 @@ def fn(input):
                     out1 = wrapped_m(input)
                     return out0 + out1
 
-                torch.autograd.gradcheck(fn, (input.clone().requires_grad_(),))
+                gradcheck(fn, (input.clone().requires_grad_(),), check_batched_grad=False)
 
                 # test removing
                 pre_remove_out = wrapped_m(input)
@@ -3149,14 +3142,14 @@ def fn(input):
                     out3 = wrapped_m(input)
                     return out0 + out1 + out2 + out3
 
-                torch.autograd.gradcheck(fn, (input.clone().requires_grad_(),))
+                gradcheck(fn, (input.clone().requires_grad_(),))
 
                 # assert that backprop reaches weight_orig in eval
                 if requires_grad:
                     def fn(weight):
                         return wrapped_m(input)
 
-                    torch.autograd.gradcheck(fn, (m.weight_orig,))
+                    gradcheck(fn, (m.weight_orig,))
 
     @skipIfNoLapack
     def test_spectral_norm_load_state_dict(self):
diff --git a/test/test_overrides.py b/test/test_overrides.py
index b3c8bdd2b718..76cd737417df 100644
--- a/test/test_overrides.py
+++ b/test/test_overrides.py
@@ -776,7 +776,7 @@ def test_wrapper(self):
 class TestGradCheckOverride(TestCase):
     "Test that wrappers work with gradcheck."
     def test_gradcheck(self):
-        from torch.autograd import gradcheck, gradgradcheck
+        from torch.testing._internal.common_utils import gradcheck, gradgradcheck
 
         a = wrap(torch.tensor(5.0, dtype=torch.double))
         b = wrap(torch.tensor(6.0, dtype=torch.double))
@@ -784,8 +784,8 @@ def test_gradcheck(self):
         a.requires_grad = True
         b.requires_grad = True
 
-        gradcheck(torch.add, (a, b), raise_exception=False)
-        gradgradcheck(torch.add, (a, b), raise_exception=False)
+        gradcheck(torch.add, (a, b), raise_exception=False, check_batched_grad=False)
+        gradgradcheck(torch.add, (a, b), raise_exception=False, check_batched_grad=False)
 
         total_used_attrs = a.used_attrs.union(b.used_attrs)
         total_used_calls = a.used_calls.union(b.used_calls)
diff --git a/test/test_sparse.py b/test/test_sparse.py
index 78e3e3de1599..4cdb87329f04 100644
--- a/test/test_sparse.py
+++ b/test/test_sparse.py
@@ -11,10 +11,9 @@
 from collections import defaultdict
 import unittest
 from torch.testing._internal.common_utils import TestCase, run_tests, skipIfRocm, do_test_dtypes, \
-    do_test_empty_full, load_tests, TEST_NUMPY, TEST_SCIPY, IS_WINDOWS
+    do_test_empty_full, load_tests, TEST_NUMPY, TEST_SCIPY, IS_WINDOWS, gradcheck
 from torch.testing._internal.common_cuda import TEST_CUDA, _get_torch_cuda_version
 from numbers import Number
-from torch.autograd.gradcheck import gradcheck
 from typing import Dict, Any
 from torch.testing._internal.common_device_type import \
     (instantiate_device_type_tests, ops)
@@ -28,6 +27,8 @@
 # sharding on sandcastle. This line silences flake warnings
 load_tests = load_tests
 
+# batched grad doesn't support sparse
+gradcheck = functools.partial(gradcheck, check_batched_grad=False)
 
 def cpu_only(inner):
     @functools.wraps(inner)
diff --git a/test/test_unary_ufuncs.py b/test/test_unary_ufuncs.py
index 3497ccd04cc1..97efacfbad86 100644
--- a/test/test_unary_ufuncs.py
+++ b/test/test_unary_ufuncs.py
@@ -11,7 +11,8 @@
 from torch._six import inf, nan
 from torch.testing._internal.common_utils import (
     TestCase, run_tests, torch_to_numpy_dtype_dict, numpy_to_torch_dtype_dict,
-    suppress_warnings, IS_MACOS, make_tensor, TEST_SCIPY, slowTest, skipIfNoSciPy)
+    suppress_warnings, IS_MACOS, make_tensor, TEST_SCIPY, slowTest, skipIfNoSciPy,
+    gradcheck)
 from torch.testing._internal.common_methods_invocations import (
     unary_ufuncs)
 from torch.testing._internal.common_device_type import (
@@ -1113,8 +1114,7 @@ def test_polygamma(self, device, dtype):
         cpu_tensor.requires_grad = True
         for n in [0, 1, 2, 3, 4, 5]:
-            torch.autograd.gradcheck(lambda x: x.polygamma(n), cpu_tensor,
-                                     check_batched_grad=True)
+            gradcheck(lambda x: x.polygamma(n), cpu_tensor)
 
 
     # TODO: update to compare against NumPy by rationalizing with OpInfo
     @onlyCUDA
diff --git a/torch/testing/_internal/common_nn.py b/torch/testing/_internal/common_nn.py
index c3250c69656b..79c581d04b05 100644
--- a/torch/testing/_internal/common_nn.py
+++ b/torch/testing/_internal/common_nn.py
@@ -16,11 +16,10 @@
 import torch.nn.functional as F
 from torch.nn import _reduction as _Reduction
 from torch.testing._internal.common_utils import TestCase, to_gpu, freeze_rng_state, is_iterable, \
-    TEST_WITH_ROCM
+    TEST_WITH_ROCM, gradcheck, gradgradcheck
 from torch.testing._internal.common_cuda import TEST_CUDA
 from torch.testing._internal.common_device_type import expectedAlertNondeterministic
-from torch.autograd.gradcheck import get_numerical_jacobian, iter_tensors, \
-    gradcheck, gradgradcheck
+from torch.autograd.gradcheck import get_numerical_jacobian, iter_tensors
 from torch.autograd import Variable
 from torch.types import _TensorOrTensors
 import torch.backends.cudnn
diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py
index 2a651865286b..2479d5729e0a 100644
--- a/torch/testing/_internal/common_utils.py
+++ b/torch/testing/_internal/common_utils.py
@@ -52,8 +52,6 @@
 import torch.backends.cudnn
 import torch.backends.mkl
 from enum import Enum
-from torch.autograd import gradcheck
-from torch.autograd.gradcheck import gradgradcheck
 
 
 torch.backends.disable_global_flags()
@@ -1894,11 +1892,42 @@ def __enter__(self):
     def __exit__(self, *args):
         pass
 
-def _assertGradAndGradgradChecks(test_case, apply_fn, inputs, check_batched_grad=False):
+
+def gradcheck(fn, inputs, **kwargs):
+    # Wrapper around gradcheck that enables certain keys by default.
+    # Use this testing-internal gradcheck instead of autograd.gradcheck so that new features like vmap and
+    # forward-mode AD are tested by default. We create this wrapper because we'd like to keep new checks
+    # disabled by default in the public-facing API, to avoid breaking user code.
+    #
+    # All PyTorch devs doing testing should use this wrapper instead of autograd.gradcheck.
+    keys_enabled_by_default = (
+        "check_batched_grad",)
+
+    for key in keys_enabled_by_default:
+        kwargs[key] = kwargs.get(key, True)
+
+    return torch.autograd.gradcheck(fn, inputs, **kwargs)
+
+
+def gradgradcheck(fn, inputs, grad_outputs=None, **kwargs):
+    # Wrapper around gradgradcheck that enables certain keys by default.
+    # See gradcheck above for an explanation of why we need something like this.
+    #
+    # All PyTorch devs doing testing should use this wrapper instead of autograd.gradgradcheck.
+    keys_enabled_by_default = (
+        "check_batched_grad",)
+
+    for key in keys_enabled_by_default:
+        kwargs[key] = kwargs.get(key, True)
+
+    return torch.autograd.gradgradcheck(fn, inputs, grad_outputs, **kwargs)
+
+
+def _assertGradAndGradgradChecks(test_case, apply_fn, inputs, **kwargs):
     # call assert function rather than returning a bool since it's nicer
     # if we get whether this failed on the gradcheck or the gradgradcheck.
-    test_case.assertTrue(gradcheck(apply_fn, inputs, check_batched_grad=check_batched_grad))
-    test_case.assertTrue(gradgradcheck(apply_fn, inputs, check_batched_grad=check_batched_grad))
+    test_case.assertTrue(gradcheck(apply_fn, inputs, **kwargs))
+    test_case.assertTrue(gradgradcheck(apply_fn, inputs, **kwargs))
 
 
 @contextmanager
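
Not part of the patch above: a minimal sketch of how a test file would consume the new common_utils wrappers. The wrapper names, import path, and the check_batched_grad default come from the torch/testing/_internal/common_utils.py hunk; the test class, the tensor shape, and the choice of torch.sin are hypothetical and only illustrate the intended usage.

# Illustrative sketch only; not part of the diff.
import torch
from torch.testing._internal.common_utils import TestCase, run_tests, gradcheck, gradgradcheck


class TestWrapperUsage(TestCase):
    def test_sin_gradients(self):
        x = torch.randn(3, dtype=torch.double, requires_grad=True)
        # check_batched_grad=True is supplied by the wrapper's defaults; tests that
        # cannot support batched grad opt out per call or via functools.partial,
        # as the test_mkldnn.py and test_sparse.py hunks above do.
        self.assertTrue(gradcheck(torch.sin, (x,)))
        self.assertTrue(gradgradcheck(torch.sin, (x,)))


if __name__ == '__main__':
    run_tests()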