Add nondeterministic alert to index_copy, median CUDA and kthvalue CUDA (#46942)

Summary:
Also fixes issue where skipped tests did not properly restore deterministic flag.

Fixes #46743

Pull Request resolved: #46942

Reviewed By: heitorschueroff

Differential Revision: D25298020

Pulled By: mruberry

fbshipit-source-id: 14b1680e1fa536ec72018d0cdb0a3cf83b098767
kurtamohler authored and facebook-github-bot committed Dec 3, 2020
1 parent c2ad3c4 commit 2cb9204
Showing 9 changed files with 188 additions and 56 deletions.
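
A minimal sketch of the user-visible behavior added here (assuming a CUDA build; the error wording below is paraphrased, not the exact message): with deterministic mode on, the newly alerted ops raise instead of silently running a nondeterministic kernel.

    import torch

    torch.set_deterministic(True)          # beta API at the time of this commit

    a = torch.randn(10, device='cuda')
    try:
        torch.kthvalue(a, 5)               # CUDA kthvalue now trips alertNotDeterministic
    except RuntimeError as e:
        print(e)                           # message names the op, e.g. "kthvalue CUDA ..."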
3 changes: 3 additions & 0 deletions aten/src/ATen/native/TensorAdvancedIndexing.cpp
@@ -333,6 +333,9 @@ Tensor & index_put_(Tensor & self, TensorList indices, const Tensor & value, con
}

Tensor & index_copy_(Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) {
// See note [Writing Nondeterministic Operations]
// Nondeterministic when index contains duplicate entries
at::globalContext().alertNotDeterministic("index_copy");
dim = maybe_wrap_dim(dim, self.dim());

TORCH_CHECK_INDEX(index.dim() < 2, "index_copy_(): Index should have dimension 1 or 0 (got ", index.dim(), ")");
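Why index_copy needs the alert: when `index` contains duplicate entries, several source elements race to write the same destination slot, so the surviving value depends on which write lands last. A small hand-made illustration (values are hypothetical; CPU shown only to make the aliasing visible):

    import torch

    dest = torch.zeros(3)
    source = torch.tensor([1., 2., 3.])
    index = torch.tensor([0, 0, 2])        # index 0 appears twice

    dest.index_copy_(0, index, source)
    # dest[0] ends up as 1. or 2. depending on which copy wins; on CUDA the
    # winner is not guaranteed, hence the alertNotDeterministic("index_copy") call above.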
8 changes: 8 additions & 0 deletions aten/src/ATen/native/cuda/Sorting.cu
@@ -316,6 +316,10 @@ std::tuple<Tensor&, Tensor&> median_with_indices_impl(
int64_t dim,
bool keepdim,
bool ignore_nan) {
// See note [Writing Nondeterministic Operations]
// If there are duplicate elements of a median value, the procedure for choosing which
// of the duplicates to use for the indices output is nondeterministic.
at::globalContext().alertNotDeterministic("median CUDA with indices output");
NoNamesGuard guard;

dim = at::maybe_wrap_dim(dim, self.dim());
@@ -410,6 +414,10 @@ std::tuple<Tensor&, Tensor&> kthvalue_out_cuda(
int64_t k,
int64_t dim,
bool keepdim) {
// See note [Writing Nondeterministic Operations]
// If there are duplicate elements of the kth value, the procedure for choosing which
// of the duplicates to use for the indices output is nondeterministic.
at::globalContext().alertNotDeterministic("kthvalue CUDA");
auto result = [&]() {
NoNamesGuard guard;
// `kthvalue_out_impl_cuda` expects contiguous in input `self`.
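Both alerts concern only the indices output: among tied elements the reported index can vary between runs on a CUDA device, while the returned value stays the same. A hedged illustration with hand-picked values:

    import torch

    a = torch.tensor([1., 2., 2., 3.], device='cuda')
    values, indices = torch.median(a, dim=0)
    # values is always 2., but indices may be 1 or 2 because the median value is tied;
    # torch.median(a) without a dim (no indices output) is not alerted.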
87 changes: 81 additions & 6 deletions test/test_torch.py
@@ -23,14 +23,15 @@
do_test_dtypes, IS_SANDCASTLE, IS_FBCODE, IS_REMOTE_GPU, load_tests, slowTest,
skipCUDANonDefaultStreamIf, skipCUDAMemoryLeakCheckIf, BytesIOContext,
skipIfRocm, skipIfNoSciPy,
wrapDeterministicFlagAPITest)
wrapDeterministicFlagAPITest, DeterministicGuard)
from multiprocessing.reduction import ForkingPickler
from torch.testing._internal.common_device_type import (
instantiate_device_type_tests,
skipCUDAIfNoMagma, skipCUDAIfRocm, skipCUDAIfNotRocm,
onlyCUDA, onlyCPU,
dtypes, dtypesIfCUDA, dtypesIfCPU, deviceCountAtLeast,
PYTORCH_CUDA_MEMCHECK, largeTensorTest, onlyOnCPUAndCUDA)
PYTORCH_CUDA_MEMCHECK, largeTensorTest, onlyOnCPUAndCUDA,
expectedAlertNondeterministic)
from typing import Dict, List
import torch.backends.quantized
import torch.testing._internal.data
@@ -2823,8 +2824,7 @@ def _rand_shape(self, dim, min_size, max_size):

@onlyCPU
def test_set_deterministic_beta_warning(self, device):
det = torch.is_deterministic()
try:
with DeterministicGuard(torch.is_deterministic()):
# Ensures setting to false does not throw a warning
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
@@ -2834,8 +2834,6 @@ def test_set_deterministic_beta_warning(self, device):
# Setting set_deterministic(True) throws a warning once per process
with self.maybeWarnsRegex(UserWarning, "torch.set_deterministic is in beta"):
torch.set_deterministic(True)
finally:
torch.set_deterministic(det)

@dtypes(torch.float32, torch.complex64)
def test_storage(self, device, dtype):
@@ -3483,6 +3481,29 @@ def _test_in_place_broadcastable(t0, t1, t2=None):
_test_in_place_broadcastable(small2, small_expanded, large_expanded)
_test_in_place_broadcastable(small2, small, large)

# Ensures that kthvalue throws nondeterministic alerts in the correct cases
@dtypes(torch.double)
def test_kthvalue_nondeterministic_alert(self, device, dtype):
@expectedAlertNondeterministic('kthvalue CUDA', 'cuda')
def test_func(slf, device, call_type):
S = 10
k = 5
a = torch.randn(S, device=device)
if call_type == 'function':
torch.kthvalue(a, k)
elif call_type == 'method':
a.kthvalue(k)
elif call_type == 'out':
values = torch.empty_like(a)
indices = torch.empty((), device=device, dtype=torch.long)
torch.kthvalue(a, k, out=(values, indices))
else:
self.fail(f"'{call_type}' is not a valid call type")

test_func(self, device, 'function')
test_func(self, device, 'method')
test_func(self, device, 'out')

def test_embedding_scalar_weight_error(self, device):
indices = torch.rand(2, 2, device=device).long()
weight = torch.tensor(1.0)
@@ -3503,6 +3524,37 @@ def run_test(x, y):
y[1] = 1.
run_test(x, y)

# Ensures that median throws nondeterministic alerts in the correct cases
@dtypes(torch.double)
def test_median_nondeterministic_alert(self, device, dtype):
def test_func(slf, device, call_type):
S = 10
a = torch.randn(S, device=device)
if call_type == 'function':
torch.median(a)
elif call_type == 'function with indices':
torch.median(a, 0)
elif call_type == 'method':
a.median()
elif call_type == 'method with indices':
a.median(0)
elif call_type == 'out with indices':
result = torch.empty_like(a)
indices = torch.empty((), dtype=torch.long, device=device)
torch.median(a, 0, out=(result, indices))
else:
self.fail(f"'{call_type}' is not a valid call type")

@expectedAlertNondeterministic('median CUDA with indices output', 'cuda')
def test_func_expect_error(slf, device, call_type):
test_func(slf, device, call_type)

test_func(self, device, 'function')
test_func_expect_error(self, device, 'function with indices')
test_func(self, device, 'method')
test_func_expect_error(self, device, 'method with indices')
test_func_expect_error(self, device, 'out with indices')

@skipCUDANonDefaultStreamIf(True)
def test_multinomial_alias(self, device):
# Get probs vector to use in setup
@@ -4289,6 +4341,29 @@ def test_index_copy(self, device):
c = torch.zeros(3)
self.assertRaises(IndexError, lambda: a.index_copy_(dim=1, index=torch.tensor([3]), source=c))

# Ensures that index_copy throws nondeterministic alerts in the correct cases
@onlyOnCPUAndCUDA
@dtypes(torch.double)
def test_index_copy_nondeterministic_alert(self, device, dtype):
@expectedAlertNondeterministic('index_copy')
def test_func(slf, device, call_type):
S = 10
a = torch.randn(S, device=device)
b = torch.randn(S, device=device)
index = torch.randint(S, (S,), device=device)
if call_type == 'function':
torch.index_copy(a, 0, index, b)
elif call_type == 'method':
a.index_copy(0, index, b)
elif call_type == 'method inplace':
a.index_copy_(0, index, b)
else:
self.fail(f"'{call_type}' is not a valid call type")

test_func(self, device, 'function')
test_func(self, device, 'method')
test_func(self, device, 'method inplace')

def test_index_fill(self, device):
for dt in torch.testing.get_all_dtypes():
if dt == torch.half or dt == torch.bfloat16 or dt.is_complex:
3 changes: 3 additions & 0 deletions torch/__init__.py
@@ -374,10 +374,13 @@ def set_deterministic(d):
* :class:`torch.nn.EmbeddingBag` when called on a CUDA tensor that requires grad
* :func:`torch.scatter_add_` when called on a CUDA tensor
* :func:`torch.index_add_` when called on a CUDA tensor
* :func:`torch.index_copy`
* :func:`torch.index_select` when called on a CUDA tensor that requires grad
* :func:`torch.repeat_interleave` when called on a CUDA tensor that requires grad
* :func:`torch.histc` when called on a CUDA tensor
* :func:`torch.bincount` when called on a CUDA tensor
* :func:`torch.kthvalue` when called on a CUDA tensor
* :func:`torch.median` with indices output when called on a CUDA tensor
A handful of CUDA operations are nondeterministic if the CUDA version is
10.2 or greater, unless the environment variable `CUBLAS_WORKSPACE_CONFIG=:4096:8`
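A hedged sketch of how a script might opt in, given the CUBLAS workspace note above (the variable has to be visible to the CUDA runtime before it initializes, so set it at the very top of the process):

    import os
    os.environ.setdefault('CUBLAS_WORKSPACE_CONFIG', ':4096:8')   # before any CUDA work

    import torch
    torch.set_deterministic(True)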
5 changes: 5 additions & 0 deletions torch/_tensor_docs.py
@@ -1768,6 +1768,11 @@ def add_docstr_all(method, docstr):
length of :attr:`index` (which must be a vector), and all other dimensions must
match :attr:`self`, or an error will be raised.
.. note::
If :attr:`index` contains duplicate entries, multiple elements from
:attr:`tensor` will be copied to the same index of :attr:`self`. The result
is nondeterministic since it depends on which copy occurs last.
Args:
dim (int): dimension along which to index
index (LongTensor): indices of :attr:`tensor` to select from
5 changes: 5 additions & 0 deletions torch/_torch_docs.py
@@ -3894,6 +3894,11 @@ def merge_dicts(*dicts):
(see :func:`torch.squeeze`), resulting in both the :attr:`values` and
:attr:`indices` tensors having 1 fewer dimension than the :attr:`input` tensor.
.. note::
When :attr:`input` is a CUDA tensor and there are multiple valid
:attr:`k` th values, this function may nondeterministically return
:attr:`indices` for any of them.
Args:
{input}
k (int): k for the k-th smallest element
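To make the new kthvalue note concrete, a small hypothetical case with ties:

    import torch

    a = torch.tensor([3., 1., 3., 2.], device='cuda')
    values, indices = torch.kthvalue(a, 3)   # 3rd smallest value is 3., present at positions 0 and 2
    # values is 3. either way; indices may come back as 0 or 2 from run to run on CUDA.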
33 changes: 15 additions & 18 deletions torch/testing/_internal/common_device_type.py
@@ -11,7 +11,7 @@
import torch
from torch.testing._internal.common_utils import TestCase, TEST_WITH_ROCM, TEST_MKL, \
skipCUDANonDefaultStreamIf, TEST_WITH_ASAN, TEST_WITH_UBSAN, TEST_WITH_TSAN, \
IS_SANDCASTLE, IS_FBCODE, IS_REMOTE_GPU
IS_SANDCASTLE, IS_FBCODE, IS_REMOTE_GPU, DeterministicGuard
from torch.testing._internal.common_cuda import _get_torch_cuda_version
from torch.testing import \
(get_all_dtypes)
@@ -790,24 +790,21 @@ def __call__(self, fn):
@wraps(fn)
def efail_fn(slf, device, *args, **kwargs):
if self.device_type is None or self.device_type == slf.device_type:
deterministic_restore = torch.is_deterministic()
torch.set_deterministic(True)
try:
if self.fn_has_device_arg:
fn(slf, device, *args, **kwargs)
with DeterministicGuard(True):
try:
if self.fn_has_device_arg:
fn(slf, device, *args, **kwargs)
else:
fn(slf, *args, **kwargs)
except RuntimeError as e:
if self.error_message not in str(e):
slf.fail(
'expected non-deterministic error message to start with "'
+ self.error_message
+ '" but got this instead: "' + str(e) + '"')
return
else:
fn(slf, *args, **kwargs)
except RuntimeError as e:
torch.set_deterministic(deterministic_restore)
if self.error_message not in str(e):
slf.fail(
'expected non-deterministic error message to start with "'
+ self.error_message
+ '" but got this instead: "' + str(e) + '"')
return
else:
torch.set_deterministic(deterministic_restore)
slf.fail('expected a non-deterministic error, but it was not raised')
slf.fail('expected a non-deterministic error, but it was not raised')

if self.fn_has_device_arg:
return fn(slf, device, *args, **kwargs)
6 changes: 6 additions & 0 deletions torch/testing/_internal/common_methods_invocations.py
@@ -1105,6 +1105,8 @@ def method_tests():
('mean', (S, S, S), (), 'dtype', (True,), (), (), ident, {'dtype': torch.float64}),
('kthvalue', (S, S, S), (2,)),
('kthvalue', (S, S, S), (2, 1,), 'dim', (), [1]),
('kthvalue', (S, S, S), (2, 1,), 'dim_alert_nondeterministic', (), [1],
[expectedAlertNondeterministic('kthvalue CUDA', 'cuda')]),
('kthvalue', (S, S, S), (2, 1, True,), 'keepdim_dim', (), [1]),
('kthvalue', (S,), (2, 0,), 'dim_1d', (), [1]),
('kthvalue', (S,), (2, 0, True,), 'keepdim_dim_1d', (), [1]),
@@ -1123,6 +1125,8 @@ def method_tests():
('nanquantile', (), (0.5,), 'scalar'),
('median', (S, S, S), NO_ARGS),
('median', (S, S, S), (1,), 'dim', (), [0]),
('median', (S, S, S), (1,), 'dim_alert_nondeterministic', (), [0],
[expectedAlertNondeterministic('median CUDA with indices output', 'cuda')]),
('median', (S, S, S), (1, True,), 'keepdim_dim', (), [0]),
('median', (), NO_ARGS, 'scalar'),
('median', (), (0,), 'scalar_dim', (), [0]),
@@ -1446,6 +1450,8 @@ def method_tests():
('index_add', (S, S), (0, index_variable(2, S), (2, S)), 'alert_nondeterministic', (), [0],
[expectedAlertNondeterministic('index_add_cuda_', 'cuda')]),
('index_copy', (S, S), (0, index_perm_variable(2, S), (2, S)), 'dim', (), [0]),
('index_copy', (S, S), (0, index_perm_variable(2, S), (2, S)), 'dim_alert_nondeterministic', (), [0],
[expectedAlertNondeterministic('index_copy')]),
('index_copy', (), (0, torch.tensor([0], dtype=torch.int64), (1,)), 'scalar_input_dim', (), [0]),
('index_copy', (), (0, torch.tensor(0, dtype=torch.int64), ()), 'scalar_all_dim', (), [0]),
('index_fill', (S, S), (0, index_variable(2, S), 2), 'dim', (), [0]),
94 changes: 62 additions & 32 deletions torch/testing/_internal/common_utils.py
@@ -398,43 +398,73 @@ def wrapper(*args, **kwargs):
fn(*args, **kwargs)
return wrapper

# Context manager for setting deterministic flag and automatically
# resetting it to its original value
class DeterministicGuard:
def __init__(self, deterministic):
self.deterministic = deterministic

def __enter__(self):
self.deterministic_restore = torch.is_deterministic()
torch.set_deterministic(self.deterministic)

def __exit__(self, exception_type, exception_value, traceback):
torch.set_deterministic(self.deterministic_restore)

# This decorator can be used for API tests that call torch.set_deterministic().
# When the test is finished, it will restore the previous deterministic flag
# setting. Also, if CUDA >= 10.2, this will set the environment variable
# CUBLAS_WORKSPACE_CONFIG=:4096:8 so that the error associated with that setting
# is not thrown during the test unless the test changes that variable on purpose.
# The previous CUBLAS_WORKSPACE_CONFIG setting will also be restored once the
# test is finished.
# setting.
#
# If CUDA >= 10.2, this will set the environment variable
# CUBLAS_WORKSPACE_CONFIG=:4096:8 so that the error associated with that
# setting is not thrown during the test unless the test changes that variable
# on purpose. The previous CUBLAS_WORKSPACE_CONFIG setting will also be
# restored once the test is finished.
#
# Note that if a test requires CUDA to actually register the changed
# CUBLAS_WORKSPACE_CONFIG variable, a new subprocess must be created, because
# CUDA only checks the variable when the runtime initializes. Tests can be
# run inside a subprocess like so:
#
# import subprocess, sys, os
# script = '''
# # Test code should go here
# '''
# try:
# subprocess.check_output(
# [sys.executable, '-c', script],
# stderr=subprocess.STDOUT,
# cwd=os.path.dirname(os.path.realpath(__file__)),
# env=os.environ.copy())
# except subprocess.CalledProcessError as e:
# error_message = e.output.decode('utf-8')
# # Handle exceptions raised by the subprocess here
#
def wrapDeterministicFlagAPITest(fn):
@wraps(fn)
def wrapper(*args, **kwargs):
deterministic_restore = torch.is_deterministic()

is_cuda10_2_or_higher = (
(torch.version.cuda is not None)
and ([int(x) for x in torch.version.cuda.split(".")] >= [10, 2]))

if is_cuda10_2_or_higher:
cublas_var_name = 'CUBLAS_WORKSPACE_CONFIG'
cublas_config_restore = os.environ.get(cublas_var_name)
os.environ[cublas_var_name] = ':4096:8'

def restore():
torch.set_deterministic(deterministic_restore)
if is_cuda10_2_or_higher:
cur_cublas_config = os.environ.get(cublas_var_name)
if cublas_config_restore is None:
if cur_cublas_config is not None:
del os.environ[cublas_var_name]
else:
os.environ[cublas_var_name] = cublas_config_restore
try:
fn(*args, **kwargs)
except RuntimeError:
restore()
raise
else:
restore()
with DeterministicGuard(torch.is_deterministic()):
class CuBLASConfigGuard:
cublas_var_name = 'CUBLAS_WORKSPACE_CONFIG'

def __enter__(self):
self.is_cuda10_2_or_higher = (
(torch.version.cuda is not None)
and ([int(x) for x in torch.version.cuda.split(".")] >= [10, 2]))
if self.is_cuda10_2_or_higher:
self.cublas_config_restore = os.environ.get(self.cublas_var_name)
os.environ[self.cublas_var_name] = ':4096:8'

def __exit__(self, exception_type, exception_value, traceback):
if self.is_cuda10_2_or_higher:
cur_cublas_config = os.environ.get(self.cublas_var_name)
if self.cublas_config_restore is None:
if cur_cublas_config is not None:
del os.environ[self.cublas_var_name]
else:
os.environ[self.cublas_var_name] = self.cublas_config_restore
with CuBLASConfigGuard():
fn(*args, **kwargs)
return wrapper

def skipIfCompiledWithoutNumpy(fn):
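A quick usage sketch for the new DeterministicGuard helper introduced above (the test body is hypothetical):

    import torch
    from torch.testing._internal.common_utils import DeterministicGuard

    def test_something(self):
        with DeterministicGuard(True):     # flag restored on exit, even if the body raises
            assert torch.is_deterministic()
            # ... exercise code that must run in deterministic mode ...
        # the previous setting is back in effect here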
