ENH: toggle TORCH_WARN_ONCE to TORCH_WARN for tests (#48560)
Summary:
Toward fixing #47624

~Step 1: add `TORCH_WARN_MAYBE`, which can either warn once or warn every time in C++, and add a C++ function to toggle the value.
Step 2 will be to expose this to Python for tests. Should I continue in this PR, or should we take a different approach: add the Python-level exposure without changing any C++ code, then over a series of PRs change each call site to use the new macro and update the tests to make sure the warning is checked?~

Step 1: add a Python and C++ toggle that converts TORCH_WARN_ONCE into TORCH_WARN so the warnings can be caught in tests
Step 2: add a Python-level decorator that uses this toggle in tests
Step 3 (in future PRs): use the decorator to catch the warnings instead of `maybeWarnsRegex`; see the usage sketch after this list
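
A rough sketch of the intended test-side flow under the new toggle (the op call is a placeholder; only `torch.set_warn_always` and `torch.is_warn_always_enabled` are added by this PR):

```python
import warnings
import torch

prev = torch.is_warn_always_enabled()
torch.set_warn_always(True)  # TORCH_WARN_ONCE sites now warn on every call
try:
    with warnings.catch_warnings(record=True) as ws:
        warnings.simplefilter("always")
        # ... call an op that hits a TORCH_WARN_ONCE site ...
    # assert on the recorded warnings in ws
finally:
    torch.set_warn_always(prev)  # restore the process-wide default
```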

Pull Request resolved: #48560

Reviewed By: ngimel

Differential Revision: D26171175

Pulled By: mruberry

fbshipit-source-id: d83c18f131d282474a24c50f70a6eee82687158f
mattip authored and facebook-github-bot committed Feb 8, 2021
1 parent d454a84 commit b97a040
Showing 9 changed files with 127 additions and 9 deletions.
10 changes: 10 additions & 0 deletions c10/util/Exception.cpp
@@ -131,6 +131,16 @@ WarningHandler* get_warning_handler() noexcept(true) {
  return ThreadWarningHandler::get_handler();
}

bool warn_always = false;

void set_warnAlways(bool setting) noexcept(true) {
  warn_always = setting;
}

bool get_warnAlways() noexcept(true) {
  return warn_always;
}

} // namespace Warning

void WarningHandler::process(
17 changes: 15 additions & 2 deletions c10/util/Exception.h
@@ -154,6 +154,12 @@ C10_API void set_warning_handler(WarningHandler* handler) noexcept(true);
/// Gets the global warning handler.
C10_API WarningHandler* get_warning_handler() noexcept(true);

/// The TORCH_WARN_ONCE macro is difficult to test for. Use
/// set_warnAlways(true) to turn it into TORCH_WARN, which can be
/// tested for more easily.
C10_API void set_warnAlways(bool) noexcept(true);
C10_API bool get_warnAlways(void) noexcept(true);

} // namespace Warning

// Used in ATen for out-of-bound indices that can reasonably only be detected
@@ -418,19 +424,26 @@ namespace detail {
// arguments which are concatenated into the warning message using operator<<
//
#ifdef STRIP_ERROR_MESSAGES
-#define TORCH_WARN_ONCE(...) \
+#define _TORCH_WARN_ONCE(...) \
  C10_UNUSED static const auto C10_ANONYMOUS_VARIABLE(torch_warn_once_) = [&] { \
    ::c10::Warning::warn({__func__, __FILE__, static_cast<uint32_t>(__LINE__)}, {}, false); \
    return true; \
  }()
#else
-#define TORCH_WARN_ONCE(...) \
+#define _TORCH_WARN_ONCE(...) \
  C10_UNUSED static const auto C10_ANONYMOUS_VARIABLE(torch_warn_once_) = [&] { \
    ::c10::Warning::warn({__func__, __FILE__, static_cast<uint32_t>(__LINE__)}, ::c10::str(__VA_ARGS__), false); \
    return true; \
  }()
#endif

#define TORCH_WARN_ONCE(...) \
  if (::c10::Warning::get_warnAlways()) { \
    TORCH_WARN(__VA_ARGS__); \
  } else { \
    _TORCH_WARN_ONCE(__VA_ARGS__); \
  }

// ----------------------------------------------------------------------------
// Deprecated macros
// ----------------------------------------------------------------------------
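A minimal C++ sketch of the resulting behavior (an illustration only, assuming a program linked against c10 that includes this header):

```cpp
#include <c10/util/Exception.h>

void f() {
  TORCH_WARN_ONCE("f() is deprecated");
}

int main() {
  f();                                  // warns once
  f();                                  // suppressed: this call site's static guard already ran
  c10::Warning::set_warnAlways(true);   // TORCH_WARN_ONCE now expands to TORCH_WARN
  f();                                  // warns
  f();                                  // warns again, on every call while the flag is set
  c10::Warning::set_warnAlways(false);  // restore the default
  return 0;
}
```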
2 changes: 2 additions & 0 deletions docs/source/torch.rst
@@ -563,5 +563,7 @@ Utilities
promote_types
use_deterministic_algorithms
are_deterministic_algorithms_enabled
set_warn_always
is_warn_always_enabled
vmap
_assert
21 changes: 21 additions & 0 deletions test/test_torch.py
@@ -3295,6 +3295,27 @@ def warn_fn():
            self.assertEqual(frameinfo.lineno - 6, warning.lineno)
            self.assertEqual(len(w), 1)

    @onlyCPU
    def test_warn_always_caught(self, device):
        # Check that we can catch a TORCH_WARN_ONCE warning twice
        # since assertWarnsOnceRegex uses set_warn_always(True), which changes
        # TORCH_WARN_ONCE to TORCH_WARN
        a = np.arange(10)
        a.flags.writeable = False
        with self.assertWarnsOnceRegex(UserWarning, '.*non-writeable.*'):
            torch.from_numpy(a)

        # OK, got it once, now try again
        with self.assertWarnsOnceRegex(UserWarning, '.*non-writeable.*'):
            torch.from_numpy(a)

        # Make sure emitting two warnings, even if they pass the regex, will fail
        # the assertWarnsOnceRegex context manager, which only allows a single warning
        with self.assertRaisesRegex(AssertionError, '.*too many.*non-writeable.*'):
            with self.assertWarnsOnceRegex(UserWarning, '.*non-writeable.*'):
                torch.from_numpy(a)
                torch.from_numpy(a)

    # TODO: this test should be in test_nn.py
    def test_conv_transposed_backward_agnostic_to_memory_format(self, device):
        in_channels = 64
2 changes: 2 additions & 0 deletions torch/_C/__init__.pyi.in
@@ -497,6 +497,8 @@ def _get_cudnn_deterministic() -> _bool: ... # THPModule_deterministicCuDNN
def _set_cudnn_deterministic(arg: _bool) -> None: ... # THPModule_setDeterministicCuDNN
def _get_deterministic_algorithms() -> _bool: ... # THPModule_deterministicAlgorithms
def _set_deterministic_algorithms(arg: _bool) -> None: ... # THPModule_setDeterministicAlgorithms
def _get_warnAlways() -> _bool: ... # THPModule_warnAlways
def _set_warnAlways(arg: _bool) -> None: ... # THPModule_setWarnAlways
def _get_cudnn_allow_tf32() -> _bool: ... # THPModule_allowTF32CuDNN
def _set_cudnn_allow_tf32(arg: _bool) -> None: ... # THPModule_setAllowTF32CuDNN
def _get_cublas_allow_tf32() -> _bool: ... # THPModule_allowTF32CuBLAS
21 changes: 20 additions & 1 deletion torch/__init__.py
@@ -41,7 +41,8 @@
    'DoubleTensor', 'FloatTensor', 'LongTensor', 'IntTensor',
    'ShortTensor', 'CharTensor', 'ByteTensor', 'BoolTensor', 'Tensor',
    'lobpcg', 'use_deterministic_algorithms', 'set_deterministic',
-    'are_deterministic_algorithms_enabled', 'is_deterministic'
+    'are_deterministic_algorithms_enabled', 'is_deterministic',
+    'set_warn_always', 'is_warn_always_enabled',
]

################################################################################
@@ -440,6 +441,24 @@ def is_deterministic():
    return are_deterministic_algorithms_enabled()


def set_warn_always(b):
    r"""When this flag is False (default), some PyTorch warnings may only
    appear once per process. This helps avoid excessive warning information.
    Setting it to True causes these warnings to always appear, which may be
    helpful when debugging.

    Args:
        b (:class:`bool`): If True, force warnings to always be emitted.
            If False, restore the default behavior of warning only once.
    """
    _C._set_warnAlways(b)

def is_warn_always_enabled():
    r"""Returns True if the global warn_always flag is turned on. Refer to
    :func:`torch.set_warn_always` documentation for more details.
    """
    return _C._get_warnAlways()

################################################################################
# Define Storage and Tensor classes
################################################################################
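The flag is process-global state, so callers that flip it should restore the previous value. A minimal save/restore sketch (the `warn_always` context manager below is a hypothetical helper, not part of this PR):

```python
from contextlib import contextmanager

import torch

@contextmanager
def warn_always(enabled=True):
    # Hypothetical helper: flip the global flag, restore it on exit.
    prev = torch.is_warn_always_enabled()
    torch.set_warn_always(enabled)
    try:
        yield
    finally:
        torch.set_warn_always(prev)

with warn_always():
    # Inside this block, TORCH_WARN_ONCE sites warn on every call.
    assert torch.is_warn_always_enabled()
```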
36 changes: 30 additions & 6 deletions torch/csrc/Module.cpp
@@ -469,8 +469,26 @@ PyObject *THPModule_setDeterministicAlgorithms(PyObject *_unused, PyObject *arg)

PyObject *THPModule_deterministicAlgorithms(PyObject *_unused, PyObject *noargs)
{
-  if (at::globalContext().deterministicAlgorithms()) Py_RETURN_TRUE;
-  else Py_RETURN_FALSE;
+  if (at::globalContext().deterministicAlgorithms()) {
+    Py_RETURN_TRUE;
+  }
+  Py_RETURN_FALSE;
}

PyObject *THPModule_setWarnAlways(PyObject *_unused, PyObject *arg)
{
  THPUtils_assert(PyBool_Check(arg), "setWarnAlways expects a bool, "
                  "but got %s", THPUtils_typename(arg));
  c10::Warning::set_warnAlways(arg == Py_True);
  Py_RETURN_NONE;
}

PyObject *THPModule_warnAlways(PyObject *_unused, PyObject *noargs)
{
  if (c10::Warning::get_warnAlways()) {
    Py_RETURN_TRUE;
  }
  Py_RETURN_FALSE;
}

PyObject *THPModule_setBenchmarkCuDNN(PyObject *_unused, PyObject *arg)
@@ -489,8 +507,10 @@ PyObject *THPModule_setBenchmarkCuDNN(PyObject *_unused, PyObject *arg)

PyObject *THPModule_benchmarkCuDNN(PyObject *_unused, PyObject *noargs)
{
-  if (at::globalContext().benchmarkCuDNN()) Py_RETURN_TRUE;
-  else Py_RETURN_FALSE;
+  if (at::globalContext().benchmarkCuDNN()) {
+    Py_RETURN_TRUE;
+  }
+  Py_RETURN_FALSE;
}

PyObject *THPModule_setAllowTF32CuBLAS(PyObject *_unused, PyObject *arg)
@@ -503,8 +523,10 @@ PyObject *THPModule_setAllowTF32CuBLAS(PyObject *_unused, PyObject *arg)

PyObject *THPModule_allowTF32CuBLAS(PyObject *_unused, PyObject *noargs)
{
-  if (at::globalContext().allowTF32CuBLAS()) Py_RETURN_TRUE;
-  else Py_RETURN_FALSE;
+  if (at::globalContext().allowTF32CuBLAS()) {
+    Py_RETURN_TRUE;
+  }
+  Py_RETURN_FALSE;
}

PyObject *THPModule_setFlushDenormal(PyObject *_unused, PyObject *arg) {
@@ -634,6 +656,8 @@ static PyMethodDef TorchMethods[] = {
{"_set_cudnn_deterministic", THPModule_setDeterministicCuDNN, METH_O, nullptr},
{"_get_deterministic_algorithms", THPModule_deterministicAlgorithms, METH_NOARGS, nullptr},
{"_set_deterministic_algorithms", THPModule_setDeterministicAlgorithms, METH_O, nullptr},
{"_get_warnAlways", THPModule_warnAlways, METH_NOARGS, nullptr},
{"_set_warnAlways", THPModule_setWarnAlways, METH_O, nullptr},
{"_get_cublas_allow_tf32", THPModule_allowTF32CuBLAS, METH_NOARGS, nullptr},
{"_set_cublas_allow_tf32", THPModule_setAllowTF32CuBLAS, METH_O, nullptr},
{"_vmapmode_increment_nesting", THPModule_vmapmode_increment_nesting, METH_NOARGS, nullptr},
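The underscore-prefixed bindings registered here are the raw hooks that the public `torch.set_warn_always` / `torch.is_warn_always_enabled` wrappers call. A quick sanity check, as a sketch:

```python
import torch

torch._C._set_warnAlways(True)
assert torch._C._get_warnAlways() is True   # Py_RETURN_TRUE yields the True singleton
torch._C._set_warnAlways(False)
assert torch._C._get_warnAlways() is False
```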
2 changes: 2 additions & 0 deletions torch/overrides.py
@@ -189,6 +189,8 @@ def get_ignored_functions() -> Set[Callable]:
torch.use_deterministic_algorithms,
torch.set_deterministic,
torch.unify_type_list,
torch.is_warn_always_enabled,
torch.set_warn_always,
Tensor.__delitem__,
Tensor.__dir__,
Tensor.__getattribute__,
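Both new functions take no Tensor arguments, so they cannot be intercepted via `__torch_function__` and are listed as ignored. A quick check, as a sketch:

```python
import torch
from torch.overrides import get_ignored_functions

# No Tensor arguments -> exempt from the __torch_function__ protocol.
assert torch.set_warn_always in get_ignored_functions()
assert torch.is_warn_always_enabled in get_ignored_functions()
```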
25 changes: 25 additions & 0 deletions torch/testing/_internal/common_utils.py
@@ -1287,6 +1287,31 @@ def maybeWarnsRegex(self, category, regex=''):
                msg += '\n'
            self.fail(msg)

    @contextmanager
    def assertWarnsOnceRegex(self, category, regex=''):
        """Context manager for code that *must always* warn.

        This filters expected warnings from the test and fails if
        the expected warning is not caught. It uses set_warn_always() to force
        TORCH_WARN_ONCE to behave like TORCH_WARN.
        """
        pattern = re.compile(regex)
        with warnings.catch_warnings(record=True) as ws:
            warnings.simplefilter("always")  # allow any warning to be raised
            prev = torch.is_warn_always_enabled()
            torch.set_warn_always(True)
            try:
                yield
            finally:
                torch.set_warn_always(prev)
        if len(ws) == 0:
            self.fail('no warning caught')
        if len(ws) > 1:
            self.fail('too many warnings caught: %s' % '\n '.join([str(w) for w in ws]))
        self.assertTrue(type(ws[0].message) is category)
        self.assertTrue(re.match(pattern, str(ws[0].message)),
                        f'{pattern}, {ws[0].message}')

    def assertExpected(self, s, subname=None):
        r"""
        Test that a string matches the recorded contents of a file
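Usage then mirrors the new `test_warn_always_caught` in test_torch.py above. A compact sketch (assuming NumPy is available and that `torch.from_numpy` on a non-writeable array hits a TORCH_WARN_ONCE site, as that test relies on):

```python
import numpy as np
import torch
from torch.testing._internal.common_utils import TestCase, run_tests

class TestWarnOnce(TestCase):
    def test_from_numpy_warns_each_time(self):
        a = np.arange(3)
        a.flags.writeable = False
        # Each block must observe exactly one matching warning, even though
        # the underlying C++ site is nominally warn-once.
        with self.assertWarnsOnceRegex(UserWarning, '.*non-writeable.*'):
            torch.from_numpy(a)
        with self.assertWarnsOnceRegex(UserWarning, '.*non-writeable.*'):
            torch.from_numpy(a)

if __name__ == '__main__':
    run_tests()
```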
