Internal gradcheck wrapper in testing._internal that sets certain flags to True (#51133)

Summary:
Fixes #49409

There are many call sites where gradcheck/gradgradcheck is now implicitly invoked with `check_batched_grad=True` where it was previously False. These fall into two basic categories (illustrated by the sketch after this list):
1) the call site previously used `torch.autograd.gradcheck` but is now changed to use the globally imported function instead
2) the call site was already using the globally imported function but did not explicitly pass the `check_batched_grad` flag
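
For illustration, a minimal sketch of the two kinds of call-site change (the function `fn` and the inputs here are made-up placeholders, not taken from this diff):

```python
import torch
from torch.testing._internal.common_utils import gradcheck

def fn(x):
    return (x * x).sum()

inputs = (torch.randn(3, dtype=torch.double, requires_grad=True),)

# Category 1: the call previously went through the public API, where
# check_batched_grad defaults to False:
#     torch.autograd.gradcheck(fn, inputs)
# and is rewritten to go through the globally imported testing wrapper instead:
gradcheck(fn, inputs)

# Category 2: the wrapper was already the one being called, but the flag was
# never passed explicitly; after this PR, omitting it means
# check_batched_grad=True rather than False.
gradcheck(fn, inputs)
```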

Only in the `_assertGradAndGradgradChecks` cases, which are infrequent, did I assume that the author is aware that omitting the flag means not applying `check_batched_grad=True` (but maybe that is not the case?).

Overall, this PR in its current state assumes that unless the author explicitly specified `check_batched_grad=False`, they were probably not aware of this flag and did not intend for it to be False.

So far, the exceptions to the above (as discovered by CI) are the following suites, which either opt out at module level or pass `check_batched_grad=False` at the call site (see the sketch after this list):
 - Mkldnn (opaque tensors do not have strides) https://app.circleci.com/pipelines/github/pytorch/pytorch/264416/workflows/e4d87886-6247-4305-8526-2696130aa9a4/jobs/10401882/tests
 - all cases in test_sparse (https://app.circleci.com/pipelines/github/pytorch/pytorch/264553/workflows/3c1cbe30-830d-4acd-b240-38d833dccd9b/jobs/10407103)
 - all cases in test_overrides (https://app.circleci.com/pipelines/github/pytorch/pytorch/264553/workflows/3c1cbe30-830d-4acd-b240-38d833dccd9b/jobs/10407236)
 - test_autograd (test_LSTM_grad_and_gradgrad) - (https://app.circleci.com/pipelines/github/pytorch/pytorch/264553/workflows/3c1cbe30-830d-4acd-b240-38d833dccd9b/jobs/10407235)
 - test_data_parallel (test_data_parallel_buffers_requiring_grad) - *SIGSEGV* (https://app.circleci.com/pipelines/github/pytorch/pytorch/264820/workflows/14d89503-040d-4e3d-9f7b-0bc04833589b/jobs/10422697)
 - test_nn (https://app.circleci.com/pipelines/github/pytorch/pytorch/264919/workflows/df79e3ed-8a31-4a8e-b584-858ee99686ff/jobs/10427315)
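
A minimal sketch of the module-level opt-out pattern these suites use, mirroring the mkldnn, sparse, and data-parallel changes in the diff below:

```python
import functools
from torch.testing._internal.common_utils import gradcheck, gradgradcheck

# Batched grad is not supported here, so disable the new default once at module
# level; every gradcheck/gradgradcheck call in the file then runs with
# check_batched_grad=False unless a call site overrides it explicitly.
gradcheck = functools.partial(gradcheck, check_batched_grad=False)
gradgradcheck = functools.partial(gradgradcheck, check_batched_grad=False)
```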

A possible TODO is to prevent new tests from invoking the external gradcheck directly; a new test would then use the internal wrapper, as in the hypothetical example below.
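
As a sketch of that intended direction (this test class is hypothetical and not part of this commit), a test written against the internal wrapper gets `check_batched_grad=True` by default and can still override it per call:

```python
import torch
from torch.testing._internal.common_utils import TestCase, run_tests, gradcheck

class TestExample(TestCase):
    def test_square_grad(self):
        x = torch.randn(4, dtype=torch.double, requires_grad=True)
        # The wrapper enables check_batched_grad by default...
        self.assertTrue(gradcheck(lambda t: (t * t).sum(), (x,)))
        # ...and the flag can still be turned off per call where batched grad
        # is unsupported.
        self.assertTrue(gradcheck(lambda t: (t * t).sum(), (x,), check_batched_grad=False))

if __name__ == '__main__':
    run_tests()
```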

Pull Request resolved: #51133

Reviewed By: ezyang

Differential Revision: D26147919

Pulled By: soulitzer

fbshipit-source-id: dff883b50f337510a89f391ea2fd87de2d531432
soulitzer authored and facebook-github-bot committed Jan 29, 2021
1 parent 5a406c0 commit c096691
Showing 12 changed files with 77 additions and 57 deletions.
9 changes: 7 additions & 2 deletions test/distributed/test_data_parallel.py
@@ -4,14 +4,15 @@
from copy import deepcopy
from collections import OrderedDict
from itertools import product
import functools

import torch
from torch import nn
from torch.cuda.amp import autocast
import torch.nn.parallel as dp
from torch.testing._internal.common_cuda import TEST_MULTIGPU, TEST_CUDA
from torch.testing._internal.common_utils import run_tests, TestCase, repeat_test_for_types, ALL_TENSORTYPES
from torch.testing._internal.common_utils import _assertGradAndGradgradChecks
from torch.testing._internal.common_utils import _assertGradAndGradgradChecks, gradcheck
from torch.testing._internal.common_utils import dtype2prec_DONTUSE
from torch.testing._internal.common_utils import skipIfRocm
import torch.nn.functional as F
@@ -20,6 +21,10 @@

NO_NCCL = not hasattr(torch.distributed, "ProcessGroupNCCL")

# batched grad doesn't support data parallel
gradcheck = functools.partial(gradcheck, check_batched_grad=False)
_assertGradAndGradgradChecks = functools.partial(_assertGradAndGradgradChecks, check_batched_grad=False)

class TestDataParallel(TestCase):

@unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
@@ -42,7 +47,7 @@ def forward(self, x):
def fn(t):
return dpm(inp)

torch.autograd.gradcheck(fn, (m.t_rg,))
gradcheck(fn, (m.t_rg,))

@unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
def test_data_parallel_rnn(self):
5 changes: 3 additions & 2 deletions test/distributions/test_distributions.py
@@ -36,9 +36,10 @@
torch.set_default_dtype(torch.double)

from torch._six import inf
from torch.testing._internal.common_utils import TestCase, run_tests, set_rng_seed, TEST_WITH_UBSAN, load_tests
from torch.testing._internal.common_utils import TestCase, run_tests, set_rng_seed, TEST_WITH_UBSAN, load_tests, \
gradcheck
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.autograd import grad, gradcheck
from torch.autograd import grad
from torch.autograd.functional import jacobian
from torch.distributions import (Bernoulli, Beta, Binomial, Categorical,
Cauchy, Chi2, ContinuousBernoulli, Dirichlet,
22 changes: 8 additions & 14 deletions test/test_autograd.py
@@ -22,7 +22,6 @@

from torch import nn
from torch._six import inf, nan, istuple
from torch.autograd.gradcheck import gradgradcheck, gradcheck
from torch.autograd.function import once_differentiable
from torch.autograd.profiler import (profile, format_time, EventList,
FunctionEvent, FunctionEventAvg,
@@ -34,7 +33,8 @@
suppress_warnings, slowTest,
load_tests, random_symmetric_matrix,
IS_WINDOWS, IS_MACOS, CudaMemoryLeakCheck,
TemporaryFileName, TEST_WITH_ROCM)
TemporaryFileName, TEST_WITH_ROCM,
gradcheck, gradgradcheck)
from torch.autograd import Variable, Function, detect_anomaly, kineto_available
from torch.autograd.function import InplaceFunction
import torch.autograd.forward_ad as fwAD
@@ -74,10 +74,6 @@ def getattr_qualified(obj, qname, default=None):

PRECISION = 1e-4

# See #49409, we should remove these if we end up with a global gradcheck setting
gradcheck = partial(gradcheck, check_batched_grad=True)
gradgradcheck = partial(gradgradcheck, check_batched_grad=True)


@contextlib.contextmanager
def backward_engine(engine):
@@ -2217,10 +2213,10 @@ def backward(ctx, grad_output):

x = torch.tensor(2).double().requires_grad_()

self.assertTrue(torch.autograd.gradcheck(double, x))
self.assertTrue(torch.autograd.gradgradcheck(double, x))
self.assertTrue(torch.autograd.gradcheck(double2, x))
self.assertTrue(torch.autograd.gradgradcheck(double2, x))
self.assertTrue(gradcheck(double, x))
self.assertTrue(gradgradcheck(double, x))
self.assertTrue(gradcheck(double2, x))
self.assertTrue(gradgradcheck(double2, x))

y = double(x)
torch.autograd.grad(y, x, create_graph=True)
@@ -7065,15 +7061,13 @@ def test_lstmcell_backward_only_one_output_grad(self, device):
self.assertFalse(s.grad is None or s.grad.abs().sum().item() == 0)

def _test_rnn_mod(self, mod, inp):
from functools import partial

def flatten_out(mod, inp):
out = mod(inp)
return tuple([t if isinstance(t, torch.Tensor) else tt for t in out for tt in t])
gradcheckfunc = partial(flatten_out, mod)
with torch.backends.cudnn.flags(enabled=False):
torch.autograd.gradcheck(gradcheckfunc, inp)
torch.autograd.gradgradcheck(gradcheckfunc, inp)
gradcheck(gradcheckfunc, inp, check_batched_grad=False)
gradgradcheck(gradcheckfunc, inp, check_batched_grad=False)

if inp.is_cuda and not TEST_WITH_ROCM:
# Assert that we have good error message around unsupported CuDNN double backward
2 changes: 1 addition & 1 deletion test/test_cpp_extensions_jit.py
@@ -13,7 +13,7 @@
import torch.backends.cudnn
import torch.utils.cpp_extension
from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME
from torch.autograd.gradcheck import gradcheck
from torch.testing._internal.common_utils import gradcheck


TEST_CUDA = torch.cuda.is_available() and CUDA_HOME is not None
11 changes: 3 additions & 8 deletions test/test_linalg.py
@@ -12,20 +12,19 @@
import random
from random import randrange
from itertools import product
from functools import reduce, partial
from functools import reduce

from torch.testing._internal.common_utils import \
(TestCase, run_tests, TEST_SCIPY, IS_MACOS, IS_WINDOWS, slowTest,
TEST_WITH_ASAN, make_tensor, TEST_WITH_ROCM, IS_FBCODE, IS_REMOTE_GPU,
wrapDeterministicFlagAPITest, iter_indices)
wrapDeterministicFlagAPITest, iter_indices, gradcheck, gradgradcheck)
from torch.testing._internal.common_device_type import \
(instantiate_device_type_tests, dtypes,
onlyCPU, skipCUDAIf, skipCUDAIfNoMagma, skipCPUIfNoLapack, precisionOverride,
skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfRocm, onlyOnCPUAndCUDA, dtypesIfCUDA,
onlyCUDA)
from torch.testing import floating_and_complex_types, floating_types, all_types
from torch.testing._internal.common_cuda import SM53OrLater, tf32_on_and_off, CUDA11OrLater, CUDA9
from torch.autograd import gradcheck, gradgradcheck

# Protects against includes accidentally setting the default dtype
# NOTE: jit_metaprogramming_utils sets the default dtype to double!
@@ -35,10 +34,6 @@
if TEST_SCIPY:
import scipy

# See #49409, we should remove these if we end up with a global gradcheck setting
gradcheck = partial(gradcheck, check_batched_grad=True)
gradgradcheck = partial(gradgradcheck, check_batched_grad=True)

class TestLinalg(TestCase):
exact_dtype = True

@@ -3218,7 +3213,7 @@ def check(equation, *operands):

# Check autograd
ops = [op.detach().requires_grad_() for op in operands]
self.assertTrue(torch.autograd.gradcheck(lambda *ops: torch.einsum(equation, ops), ops))
self.assertTrue(gradcheck(lambda *ops: torch.einsum(equation, ops), ops))
for op in ops:
self.assertTrue(op._version == 0)

7 changes: 5 additions & 2 deletions test/test_mkldnn.py
@@ -1,5 +1,6 @@
import copy
import itertools
import functools
import unittest

try:
@@ -15,9 +16,11 @@
import torch.jit
import torch.backends.mkldnn
from torch.utils import mkldnn as mkldnn_utils
from torch.testing._internal.common_utils import TestCase, run_tests, TemporaryFileName
from torch.testing._internal.common_utils import TestCase, run_tests, TemporaryFileName, gradcheck, gradgradcheck

from torch.autograd.gradcheck import gradgradcheck, gradcheck
# batched grad doesn't support mkldnn
gradcheck = functools.partial(gradcheck, check_batched_grad=False)
gradgradcheck = functools.partial(gradgradcheck, check_batched_grad=False)

types = [torch.float, torch.bfloat16]

17 changes: 5 additions & 12 deletions test/test_nn.py
@@ -10,7 +10,7 @@
import pickle
from copy import deepcopy
from itertools import repeat, product
from functools import reduce, partial
from functools import reduce
from operator import mul
from collections import OrderedDict

@@ -29,8 +29,6 @@
from torch.nn.utils import clip_grad_norm_, clip_grad_value_
import torch.nn.utils.prune as prune
from torch.nn.utils import parameters_to_vector, vector_to_parameters
from torch.autograd import gradcheck
from torch.autograd.gradcheck import gradgradcheck
from torch.nn import Parameter
from torch.nn.parameter import UninitializedParameter
from torch.nn.parallel._functions import Broadcast
@@ -51,7 +49,7 @@

from hypothesis import given
import torch.testing._internal.hypothesis_utils as hu
from torch.testing._internal.common_utils import _assertGradAndGradgradChecks
from torch.testing._internal.common_utils import _assertGradAndGradgradChecks, gradcheck, gradgradcheck
from torch.testing._internal.common_utils import dtype2prec_DONTUSE
from torch.testing._internal.common_cuda import tf32_on_and_off, tf32_is_not_fp32, tf32_off, tf32_on
from torch.types import _TensorOrTensors
@@ -72,11 +70,6 @@

DOUBLE_TENSORTYPES = [torch.double]

# See #49409, we should remove these if we end up with a global gradcheck setting
gradcheck = partial(gradcheck, check_batched_grad=True)
gradgradcheck = partial(gradgradcheck, check_batched_grad=True)
_assertGradAndGradgradChecks = partial(_assertGradAndGradgradChecks, check_batched_grad=True)


# WARNING: If you add a new top-level test case to this file, you MUST
# update test/run_test.py to list it, otherwise it will NOT be run in
@@ -3093,7 +3086,7 @@ def fn(input):
out1 = wrapped_m(input)
return out0 + out1

torch.autograd.gradcheck(fn, (input.clone().requires_grad_(),))
gradcheck(fn, (input.clone().requires_grad_(),), check_batched_grad=False)

# test removing
pre_remove_out = wrapped_m(input)
@@ -3149,14 +3142,14 @@ def fn(input):
out3 = wrapped_m(input)
return out0 + out1 + out2 + out3

torch.autograd.gradcheck(fn, (input.clone().requires_grad_(),))
gradcheck(fn, (input.clone().requires_grad_(),))

# assert that backprop reaches weight_orig in eval
if requires_grad:
def fn(weight):
return wrapped_m(input)

torch.autograd.gradcheck(fn, (m.weight_orig,))
gradcheck(fn, (m.weight_orig,))

@skipIfNoLapack
def test_spectral_norm_load_state_dict(self):
6 changes: 3 additions & 3 deletions test/test_overrides.py
@@ -776,16 +776,16 @@ def test_wrapper(self):
class TestGradCheckOverride(TestCase):
"Test that wrappers work with gradcheck."
def test_gradcheck(self):
from torch.autograd import gradcheck, gradgradcheck
from torch.testing._internal.common_utils import gradcheck, gradgradcheck

a = wrap(torch.tensor(5.0, dtype=torch.double))
b = wrap(torch.tensor(6.0, dtype=torch.double))

a.requires_grad = True
b.requires_grad = True

gradcheck(torch.add, (a, b), raise_exception=False)
gradgradcheck(torch.add, (a, b), raise_exception=False)
gradcheck(torch.add, (a, b), raise_exception=False, check_batched_grad=False)
gradgradcheck(torch.add, (a, b), raise_exception=False, check_batched_grad=False)

total_used_attrs = a.used_attrs.union(b.used_attrs)
total_used_calls = a.used_calls.union(b.used_calls)
5 changes: 3 additions & 2 deletions test/test_sparse.py
@@ -11,10 +11,9 @@
from collections import defaultdict
import unittest
from torch.testing._internal.common_utils import TestCase, run_tests, skipIfRocm, do_test_dtypes, \
do_test_empty_full, load_tests, TEST_NUMPY, TEST_SCIPY, IS_WINDOWS
do_test_empty_full, load_tests, TEST_NUMPY, TEST_SCIPY, IS_WINDOWS, gradcheck
from torch.testing._internal.common_cuda import TEST_CUDA, _get_torch_cuda_version
from numbers import Number
from torch.autograd.gradcheck import gradcheck
from typing import Dict, Any
from torch.testing._internal.common_device_type import \
(instantiate_device_type_tests, ops)
@@ -28,6 +27,8 @@
# sharding on sandcastle. This line silences flake warnings
load_tests = load_tests

# batched grad doesn't support sparse
gradcheck = functools.partial(gradcheck, check_batched_grad=False)

def cpu_only(inner):
@functools.wraps(inner)
6 changes: 3 additions & 3 deletions test/test_unary_ufuncs.py
@@ -11,7 +11,8 @@
from torch._six import inf, nan
from torch.testing._internal.common_utils import (
TestCase, run_tests, torch_to_numpy_dtype_dict, numpy_to_torch_dtype_dict,
suppress_warnings, IS_MACOS, make_tensor, TEST_SCIPY, slowTest, skipIfNoSciPy)
suppress_warnings, IS_MACOS, make_tensor, TEST_SCIPY, slowTest, skipIfNoSciPy,
gradcheck)
from torch.testing._internal.common_methods_invocations import (
unary_ufuncs)
from torch.testing._internal.common_device_type import (
@@ -1113,8 +1114,7 @@ def test_polygamma(self, device, dtype):

cpu_tensor.requires_grad = True
for n in [0, 1, 2, 3, 4, 5]:
torch.autograd.gradcheck(lambda x: x.polygamma(n), cpu_tensor,
check_batched_grad=True)
gradcheck(lambda x: x.polygamma(n), cpu_tensor)

# TODO: update to compare against NumPy by rationalizing with OpInfo
@onlyCUDA
5 changes: 2 additions & 3 deletions torch/testing/_internal/common_nn.py
@@ -16,11 +16,10 @@
import torch.nn.functional as F
from torch.nn import _reduction as _Reduction
from torch.testing._internal.common_utils import TestCase, to_gpu, freeze_rng_state, is_iterable, \
TEST_WITH_ROCM
TEST_WITH_ROCM, gradcheck, gradgradcheck
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_device_type import expectedAlertNondeterministic
from torch.autograd.gradcheck import get_numerical_jacobian, iter_tensors, \
gradcheck, gradgradcheck
from torch.autograd.gradcheck import get_numerical_jacobian, iter_tensors
from torch.autograd import Variable
from torch.types import _TensorOrTensors
import torch.backends.cudnn
39 changes: 34 additions & 5 deletions torch/testing/_internal/common_utils.py
@@ -52,8 +52,6 @@
import torch.backends.cudnn
import torch.backends.mkl
from enum import Enum
from torch.autograd import gradcheck
from torch.autograd.gradcheck import gradgradcheck

torch.backends.disable_global_flags()

@@ -1894,11 +1892,42 @@ def __enter__(self):
def __exit__(self, *args):
pass

def _assertGradAndGradgradChecks(test_case, apply_fn, inputs, check_batched_grad=False):

def gradcheck(fn, inputs, **kwargs):
# Wrapper around gradcheck that enables certain keys by default.
# Use this testing-internal gradcheck instead of autograd.gradcheck so that new features like vmap and
# forward-mode AD are tested by default. We create this wrapper because we'd like to keep new checks
# disabled by default for the public-facing API to avoid breaking user code.
#
# All PyTorch devs doing testing should use this wrapper instead of autograd.gradcheck.
keys_enabled_by_default = (
"check_batched_grad",)

for key in keys_enabled_by_default:
kwargs[key] = kwargs.get(key, True)

return torch.autograd.gradcheck(fn, inputs, **kwargs)


def gradgradcheck(fn, inputs, grad_outputs=None, **kwargs):
# Wrapper around gradgradcheck that enables certain keys by default
# See gradcheck above for an explanation of why we need something like this.
#
# All PyTorch devs doing testing should use this wrapper instead of autograd.gradgradcheck
keys_enabled_by_default = (
"check_batched_grad",)

for key in keys_enabled_by_default:
kwargs[key] = kwargs.get(key, True)

return torch.autograd.gradgradcheck(fn, inputs, grad_outputs, **kwargs)


def _assertGradAndGradgradChecks(test_case, apply_fn, inputs, **kwargs):
# call assert function rather than returning a bool since it's nicer
# if we get whether this failed on the gradcheck or the gradgradcheck.
test_case.assertTrue(gradcheck(apply_fn, inputs, check_batched_grad=check_batched_grad))
test_case.assertTrue(gradgradcheck(apply_fn, inputs, check_batched_grad=check_batched_grad))
test_case.assertTrue(gradcheck(apply_fn, inputs, **kwargs))
test_case.assertTrue(gradgradcheck(apply_fn, inputs, **kwargs))


@contextmanager
