Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rename set_deterministic to use_deterministic_algorithms #49904

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
22 changes: 11 additions & 11 deletions aten/src/ATen/Context.cpp
Expand Up @@ -60,25 +60,25 @@ void Context::setDeterministicCuDNN(bool b) {
deterministic_cudnn = b;
}

bool Context::deterministic() const {
return _deterministic;
bool Context::deterministicAlgorithms() const {
return _deterministic_algorithms;
}

void Context::setDeterministic(bool b) {
void Context::setDeterministicAlgorithms(bool b) {
if (b) {
TORCH_WARN_ONCE("torch.set_deterministic is in beta, and its design and "
TORCH_WARN_ONCE("torch.use_deterministic_algorithms is in beta, and its design and"
" functionality may change in the future.");
}

_deterministic = b;
_deterministic_algorithms = b;
}

void Context::alertNotDeterministic(c10::string_view const& caller) {
if (globalContext().deterministic()) {
if (globalContext().deterministicAlgorithms()) {
TORCH_CHECK(false,
caller, " does not have a deterministic implementation, but you set "
"'torch.set_deterministic(True)'. You can turn off determinism just "
"for this operation if that's acceptable for your application. You "
"'torch.use_deterministic_algorithms(True)'. You can turn off determinism ",
"just for this operation if that's acceptable for your application. You "
"can also file an issue at https://github.com/pytorch/pytorch/issues "
"to help us prioritize adding deterministic support for this operation.");
}
Expand Down Expand Up @@ -111,9 +111,9 @@ bool Context::checkCuBLASConfigDeterministic() {

void Context::alertCuBLASConfigNotDeterministic() {
static bool cublas_config_deterministic = checkCuBLASConfigDeterministic();
TORCH_CHECK(!deterministic() || cublas_config_deterministic,
"Deterministic behavior was enabled with either `torch.set_deterministic(True)` or ",
"`at::Context::setDeterministic(true)`, but this operation is not deterministic because ",
TORCH_CHECK(!deterministicAlgorithms() || cublas_config_deterministic,
"Deterministic behavior was enabled with either `torch.use_deterministic_algorithms(True)` or ",
"`at::Context::setDeterministicAlgorithms(true)`, but this operation is not deterministic because ",
"it uses CuBLAS and you have CUDA >= 10.2. To enable deterministic behavior in this ",
"case, you must set an environment variable before running your PyTorch application: ",
cublas_config_var_name, "=", cublas_deterministic_configs[0], " or ",
Expand Down
34 changes: 18 additions & 16 deletions aten/src/ATen/Context.h
Expand Up @@ -120,27 +120,27 @@ class TORCH_API Context {
//
// * Include this comment: "See Note [Enabling Deterministic Operations]"
//
// * Check the value of `at::globalContext().deterministic()` to toggle between
// nondeterministic and deterministic implementations.
// * Check the value of `at::globalContext().deterministicAlgorithms()` to toggle
// between nondeterministic and deterministic implementations.
//
// * Have an entry in the list of PyTorch operations that toggle between nondeterministic
// and deterministic implementations, in the docstring of `set_deterministic()`
// and deterministic implementations, in the docstring of `use_deterministic_algorithms()`
// in torch/__init__.py
//
// `example_func()` below shows an example of toggling between nondeterministic and
// deterministic implementations:
//
// void example_func() {
// // See Note [Enabling Deterministic Operations]
// if (at::globalContext().deterministic()) {
// if (at::globalContext().deterministicAlgorithms()) {
// example_func_deterministic();
// } else {
// example_func_nondeterministic();
// }
// }

bool deterministic() const;
void setDeterministic(bool);
bool deterministicAlgorithms() const;
void setDeterministicAlgorithms(bool);

// Note [Writing Nondeterministic Operations]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Expand All @@ -151,16 +151,18 @@ class TORCH_API Context {
//
// * Include a comment explaining why the operation is nondeterministic.
//
// * Throw an error when `Context::deterministic()` is true. Most of the time, this
// should be accomplished by calling `at::globalContext().alertNotDeterministic()`.
// However, if the nondeterministic behavior is caused by the CuBLAS workspace
// * Throw an error when `Context::deterministicAlgorithms()` is true. Most
// of the time, this should be accomplished by calling
// `at::globalContext().alertNotDeterministic()`. However, if the
// nondeterministic behavior is caused by the CuBLAS workspace
// configuration in CUDA >= 10.2,
// `at::globalContext().alertCuBLASConfigNotDeterministic()` should
// be called instead (in this case, a comment explaining why the operation is
// nondeterministic is not necessary). See below for details on these methods.
// `at::globalContext().alertCuBLASConfigNotDeterministic()` should be
// called instead (in this case, a comment explaining why the operation is
// nondeterministic is not necessary). See below for details on these
// methods.
//
// * Have an entry in the list of nondeterministic PyTorch operations in the
// docstring of `set_deterministic()` in torch/__init__.py
// docstring of `use_deterministic_algorithms()` in torch/__init__.py
//
// `example_func()` below shows an example of the comments and error-throwing code
// for a nondeterministic operation:
Expand All @@ -172,10 +174,10 @@ class TORCH_API Context {
// ...
// }

// Throws an error if `Context::deterministic()` is true
// Throws an error if `Context::deterministicAlgorithms()` is true
void alertNotDeterministic(c10::string_view const& caller);

// Throws an error if `Context::deterministic()` is true, CUDA >= 10.2, and
// Throws an error if `Context::deterministicAlgorithms()` is true, CUDA >= 10.2, and
// CUBLAS_WORKSPACE_CONFIG is not set to either ":16:8" or ":4096:8". For more details:
// https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility
void alertCuBLASConfigNotDeterministic();
Expand Down Expand Up @@ -210,7 +212,7 @@ class TORCH_API Context {
std::once_flag thh_init;
bool enabled_cudnn = true;
bool deterministic_cudnn = false;
bool _deterministic = false;
bool _deterministic_algorithms = false;
bool benchmark_cudnn = false;
bool allow_tf32_cudnn = true;
bool allow_tf32_cublas = true;
Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/cuda/CublasHandlePool.cpp
Expand Up @@ -53,7 +53,7 @@ cublasHandle_t getCurrentCUDABlasHandle() {
#endif
#if defined(__HIP_PLATFORM_HCC__) && HIP_VERSION >= 308
rocblas_atomics_mode rocblas_mode;
if (at::globalContext().deterministic()) {
if (at::globalContext().deterministicAlgorithms()) {
rocblas_mode = rocblas_atomics_not_allowed;
} else {
rocblas_mode = rocblas_atomics_allowed;
Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/native/Convolution.cpp
Expand Up @@ -588,7 +588,7 @@ at::Tensor convolution(
bool transposed, IntArrayRef output_padding, int64_t groups) {
auto& ctx = at::globalContext();
// See Note [Enabling Deterministic Operations]
bool deterministic = ctx.deterministicCuDNN() || ctx.deterministic();
bool deterministic = ctx.deterministicCuDNN() || ctx.deterministicAlgorithms();
return at::_convolution(input, weight, bias, stride, padding, dilation,
transposed, output_padding, groups,
ctx.benchmarkCuDNN(), deterministic, ctx.userEnabledCuDNN(), ctx.allowTF32CuDNN());
Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu
Expand Up @@ -889,7 +889,7 @@ Tensor& _bmm_out_sparse_cuda(Tensor& result, const SparseTensor& self, const Ten
void* workspace_buffer = nullptr;

// See Note [Enabling Deterministic Operations]
deterministic = deterministic || globalContext().deterministic();
deterministic = deterministic || globalContext().deterministicAlgorithms();
cusparseSpMMAlg_t mm_alg = deterministic ? CUSPARSE_COOMM_ALG2 : CUSPARSE_COOMM_ALG1;

// Iterate through each set of 2D matrices within the 3D
Expand Down
3 changes: 2 additions & 1 deletion docs/source/backends.rst
Expand Up @@ -61,7 +61,8 @@ torch.backends.cudnn
.. attribute:: torch.backends.cudnn.deterministic

A :class:`bool` that, if True, causes cuDNN to only use deterministic convolution algorithms.
See also :func:`torch.is_deterministic` and :func:`torch.set_deterministic`.
See also :func:`torch.are_deterministic_algorithms_enabled` and
:func:`torch.use_deterministic_algorithms`.

.. attribute:: torch.backends.cudnn.benchmark

Expand Down
37 changes: 19 additions & 18 deletions docs/source/notes/randomness.rst
Expand Up @@ -79,34 +79,34 @@ setting discussed below.

Avoiding nondeterministic algorithms
....................................
:meth:`torch.set_deterministic` lets you configure PyTorch to use deterministic
algorithms instead of nondeterministic ones where available, and to throw an error
if an operation is known to be nondeterministic (and without a deterministic
alternative).

Please check the documentation for :meth:`torch.set_deterministic()` for a full
list of affected operations. If an operation does not act correctly according to
the documentation, or if you need a deterministic implementation of an operation
that does not have one, please submit an issue:
:meth:`torch.use_deterministic_algorithms` lets you configure PyTorch to use
deterministic algorithms instead of nondeterministic ones where available, and
to throw an error if an operation is known to be nondeterministic (and without
a deterministic alternative).

Please check the documentation for :meth:`torch.use_deterministic_algorithms()`
for a full list of affected operations. If an operation does not act correctly
according to the documentation, or if you need a deterministic implementation
of an operation that does not have one, please submit an issue:
`<https://github.com/pytorch/pytorch/issues?q=label:%22topic:%20determinism%22>`_

For example, running the nondeterministic CUDA implementation of :meth:`torch.Tensor.index_add_`
will throw an error::

>>> import torch
>>> torch.set_deterministic(True)
>>> torch.use_deterministic_algorithms(True)
>>> torch.randn(2, 2).cuda().index_add_(0, torch.tensor([0, 1]), torch.randn(2, 2))
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: index_add_cuda_ does not have a deterministic implementation, but you set
'torch.set_deterministic(True)'. ...
'torch.use_deterministic_algorithms(True)'. ...

When :meth:`torch.bmm` is called with sparse-dense CUDA tensors it typically uses a
nondeterministic algorithm, but when the deterministic flag is turned on, its alternate
deterministic implementation will be used::

>>> import torch
>>> torch.set_deterministic(True)
>>> torch.use_deterministic_algorithms(True)
>>> torch.bmm(torch.randn(2, 2, 2).to_sparse().cuda(), torch.randn(2, 2, 2).cuda())
tensor([[[ 1.1900, -2.3409],
[ 0.4796, 0.8003]],
Expand All @@ -119,12 +119,13 @@ should set the environment variable `CUBLAS_WORKSPACE_CONFIG` according to CUDA

CUDA convolution determinism
----------------------------
While disabling CUDA convolution benchmarking (discussed above) ensures that CUDA
selects the same algorithm each time an application is run, that algorithm itself
may be nondeterministic, unless either :code:`torch.set_deterministic(True)` or
:code:`torch.backends.cudnn.deterministic = True` is set. The latter setting controls
only this behavior, unlike :meth:`torch.set_deterministic` which will make other
PyTorch operations behave deterministically, too.
While disabling CUDA convolution benchmarking (discussed above) ensures that
CUDA selects the same algorithm each time an application is run, that algorithm
itself may be nondeterministic, unless either
:code:`torch.use_deterministic_algorithms(True)` or
:code:`torch.backends.cudnn.deterministic = True` is set. The latter setting
controls only this behavior, unlike :meth:`torch.use_deterministic_algorithms`
which will make other PyTorch operations behave deterministically, too.

CUDA RNN and LSTM
-----------------
Expand Down
4 changes: 2 additions & 2 deletions docs/source/torch.rst
Expand Up @@ -555,7 +555,7 @@ Utilities
result_type
can_cast
promote_types
set_deterministic
is_deterministic
use_deterministic_algorithms
are_deterministic_algorithms_enabled
vmap
_assert
2 changes: 1 addition & 1 deletion test/test_linalg.py
Expand Up @@ -4838,7 +4838,7 @@ def test_case_info(fn_name, config):
should_throw_error = is_cuda10_2_or_higher and not is_config_deterministic
script = f"""
import torch
torch.set_deterministic(True)
torch.use_deterministic_algorithms(True)
fn = torch.{fn_name}
arg_sizes = {arg_sizes}
device = '{device}'
Expand Down
32 changes: 23 additions & 9 deletions test/test_torch.py
Expand Up @@ -109,11 +109,11 @@ def test_dir(self):
@wrapDeterministicFlagAPITest
def test_deterministic_flag(self):
for deterministic in [True, False]:
torch.set_deterministic(deterministic)
self.assertEqual(deterministic, torch.is_deterministic())
torch.use_deterministic_algorithms(deterministic)
self.assertEqual(deterministic, torch.are_deterministic_algorithms_enabled())

with self.assertRaisesRegex(RuntimeError, r"set_deterministic expects a bool, but got int"):
torch.set_deterministic(1)
with self.assertRaisesRegex(RuntimeError, r"use_deterministic_algorithms expects a bool, but got int"):
torch.use_deterministic_algorithms(1)

def test_type_conversion_via_dtype_name(self):
x = torch.tensor([1])
Expand Down Expand Up @@ -2833,18 +2833,32 @@ def _rand_shape(self, dim, min_size, max_size):
return tuple(shape)

@onlyCPU
def test_set_deterministic_beta_warning(self, device):
with DeterministicGuard(torch.is_deterministic()):
def test_use_deterministic_algorithms_beta_warning(self, device):
with DeterministicGuard(torch.are_deterministic_algorithms_enabled()):
# Ensures setting to false does not throw a warning
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
torch.set_deterministic(False)
torch.use_deterministic_algorithms(False)
self.assertEqual(len(w), 0)

# Setting set_deterministic(True) throws a warning once per process
with self.maybeWarnsRegex(UserWarning, "torch.set_deterministic is in beta"):
# Setting use_deterministic_algorithms(True) throws a warning once per process
with self.maybeWarnsRegex(UserWarning, "torch.use_deterministic_algorithms is in beta"):
torch.use_deterministic_algorithms(True)

@onlyCPU
def test_set_deterministic_deprecated_warning(self, device):
with DeterministicGuard(torch.are_deterministic_algorithms_enabled()):
# Calling set_deterministic throws a warning about deprecation once per process
with self.maybeWarnsRegex(UserWarning, "torch.set_deterministic is deprecated"):
torch.set_deterministic(True)

@onlyCPU
def test_is_deterministic_deprecated_warning(self, device):
with DeterministicGuard(torch.are_deterministic_algorithms_enabled()):
# Calling is_deterministic throws a warning about deprecation once per process
with self.maybeWarnsRegex(UserWarning, "torch.is_deterministic is deprecated"):
torch.is_deterministic()

@dtypes(torch.float32, torch.complex64)
def test_storage(self, device, dtype):
v = torch.randn(3, 5, dtype=dtype, device=device)
Expand Down
4 changes: 2 additions & 2 deletions torch/_C/__init__.pyi.in
Expand Up @@ -483,8 +483,8 @@ def _get_cudnn_benchmark() -> _bool: ... # THPModule_benchmarkCuDNN
def _set_cudnn_benchmark(arg: _bool) -> None: ... # THPModule_setBenchmarkCuDNN
def _get_cudnn_deterministic() -> _bool: ... # THPModule_deterministicCuDNN
def _set_cudnn_deterministic(arg: _bool) -> None: ... # THPModule_setDeterministicCuDNN
def _get_deterministic() -> _bool: ... # THPModule_deterministic
def _set_deterministic(arg: _bool) -> None: ... # THPModule_setDeterministic
def _get_deterministic_algorithms() -> _bool: ... # THPModule_deterministicAlgorithms
def _set_deterministic_algorithms(arg: _bool) -> None: ... # THPModule_setDeterministicAlgorithms
def _get_cudnn_allow_tf32() -> _bool: ... # THPModule_allowTF32CuDNN
def _set_cudnn_allow_tf32(arg: _bool) -> None: ... # THPModule_setAllowTF32CuDNN
def _get_cublas_allow_tf32() -> _bool: ... # THPModule_allowTF32CuBLAS
Expand Down
34 changes: 28 additions & 6 deletions torch/__init__.py
Expand Up @@ -14,6 +14,7 @@
import platform
import textwrap
import ctypes
import warnings

if sys.version_info < (3,):
raise Exception("Python 2 has reached end-of-life and is no longer supported by PyTorch.")
Expand All @@ -35,7 +36,8 @@
'ShortStorage', 'CharStorage', 'ByteStorage', 'BoolStorage',
'DoubleTensor', 'FloatTensor', 'LongTensor', 'IntTensor',
'ShortTensor', 'CharTensor', 'ByteTensor', 'BoolTensor', 'Tensor',
'lobpcg', 'set_deterministic', 'is_deterministic'
'lobpcg', 'use_deterministic_algorithms', 'set_deterministic',
'are_deterministic_algorithms_enabled', 'is_deterministic'
]

################################################################################
Expand Down Expand Up @@ -325,7 +327,7 @@ def set_default_dtype(d):
"""
_C._set_default_dtype(d)

def set_deterministic(d):
def use_deterministic_algorithms(d):
r""" Sets whether PyTorch operations must use "deterministic"
algorithms. That is, algorithms which, given the same input, and when
run on the same software and hardware, always produce the same output.
Expand Down Expand Up @@ -402,13 +404,33 @@ def set_deterministic(d):
d (:class:`bool`): If True, force operations to be deterministic.
If False, allow non-deterministic operations.
"""
_C._set_deterministic(d)
_C._set_deterministic_algorithms(d)

def is_deterministic():
def set_deterministic(d):
r"""This function is deprecated and will be removed in a future release.
Please use :func:`torch.use_deterministic_algorithms` instead.
"""
warnings.warn((
"torch.set_deterministic is deprecated and will be removed in a future "
"release. Please use torch.use_deterministic_algorithms instead"))

use_deterministic_algorithms(d)

def are_deterministic_algorithms_enabled():
r"""Returns True if the global deterministic flag is turned on. Refer to
:func:`torch.set_deterministic` documentation for more details.
:func:`torch.use_deterministic_algorithms` documentation for more details.
"""
return _C._get_deterministic()
return _C._get_deterministic_algorithms()

def is_deterministic():
r"""This function is deprecated and will be removed in a future release.
Please use :func:`torch.are_deterministic_algorithms_enabled` instead.
"""
warnings.warn((
"torch.is_deterministic is deprecated and will be removed in a future "
"release. Please use torch.are_deterministic_algorithms_enabled instead"))
return are_deterministic_algorithms_enabled()


################################################################################
# Define Storage and Tensor classes
Expand Down