From c41678fd53fd145af57a32b38fcdda5d6e6754c2 Mon Sep 17 00:00:00 2001
From: Kurt Mohler
Date: Wed, 3 Feb 2021 21:31:11 -0800
Subject: [PATCH] Use deterministic impl of `index_put` and `index` backward
 CPU when `torch.are_deterministic_algorithms_enabled() == True` (#51388)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/51366

Pull Request resolved: https://github.com/pytorch/pytorch/pull/51388

Reviewed By: zou3519

Differential Revision: D26235290

Pulled By: ngimel

fbshipit-source-id: 64cce1a5e75d8a9ce9807c28d641da82ede666e2
---
 aten/src/ATen/native/cpu/IndexKernel.cpp | 9 +++++++--
 torch/__init__.py                        | 4 ++++
 2 files changed, 11 insertions(+), 2 deletions(-)

diff --git a/aten/src/ATen/native/cpu/IndexKernel.cpp b/aten/src/ATen/native/cpu/IndexKernel.cpp
index f0947de8c699..cf524d619571 100644
--- a/aten/src/ATen/native/cpu/IndexKernel.cpp
+++ b/aten/src/ATen/native/cpu/IndexKernel.cpp
@@ -111,8 +111,13 @@ void index_put_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef
   AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(ScalarType::Half, ScalarType::Bool, ScalarType::BFloat16,
     iter.dtype(), "index_put", [&] {
     if (accumulate) {
-      bool use_parallel_for = ((iter.numel() >= internal::GRAIN_SIZE) && (at::get_num_threads() > 1));
-      if (iter.dtype() == ScalarType::Float && use_parallel_for) {
+      // See Note [Enabling Deterministic Operations]
+      // Parallel cpu_index_kernel with accumulation is nondeterministic, so we
+      // must enable serial execution if deterministic algorithms are enabled.
+      bool is_deterministic = at::globalContext().deterministicAlgorithms();
+      bool use_parallel_for = (!is_deterministic) && (
+        (iter.numel() >= internal::GRAIN_SIZE) && (at::get_num_threads() > 1));
+      if (use_parallel_for && iter.dtype() == ScalarType::Float) {
         cpu_index_kernel(iter, index_size, index_stride, [](char* dst, char* src, int64_t offset) {
           cpu_atomic_add_float((float*)(dst + offset), *(float*)src);
         });
diff --git a/torch/__init__.py b/torch/__init__.py
index f27af91eb493..1b0189a11059 100644
--- a/torch/__init__.py
+++ b/torch/__init__.py
@@ -353,6 +353,10 @@ def use_deterministic_algorithms(d):
         * :class:`torch.nn.ConvTranspose2d` when called on CUDA tensor
         * :class:`torch.nn.ConvTranspose3d` when called on CUDA tensor
         * :func:`torch.bmm` when called on sparse-dense CUDA tensors
+        * :func:`torch.__getitem__` backward when `self` is a CPU tensor and
+          ``indices`` is a list of tensors
+        * :func:`torch.index_put` with ``accumulate=True`` when called on a CPU
+          tensor

     The following normally-nondeterministic operations will throw a
     :class:`RuntimeError` when `d=True`:
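
For context, here is a minimal sketch of how this change surfaces to users.
It assumes a PyTorch build that includes this patch; the tensor values,
shapes, and index choices below are purely illustrative, not taken from the
PR itself. The repeated index 0 is what makes the accumulation order matter:
with the parallel atomic-add path, float summation order can vary between
runs, while the serial path forced by this patch always produces the same
result bit-for-bit.

    import torch

    # Opt in to deterministic kernels; with this patch, CPU index_put with
    # accumulate=True now takes the serial code path instead of the
    # parallel atomic-add path.
    torch.use_deterministic_algorithms(True)
    assert torch.are_deterministic_algorithms_enabled()

    x = torch.zeros(4)
    idx = (torch.tensor([0, 0, 1, 3]),)        # duplicate index forces accumulation
    vals = torch.tensor([1.0, 2.0, 3.0, 4.0])
    x.index_put_(idx, vals, accumulate=True)
    print(x)                                   # tensor([3., 3., 0., 4.])

    # The `index` backward path covered by this patch: advanced indexing
    # on a CPU tensor scatter-accumulates gradients via index_put.
    y = torch.zeros(4, requires_grad=True)
    y[torch.tensor([0, 0, 2])].sum().backward()
    print(y.grad)                              # tensor([2., 0., 1., 0.])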