Reland: Make torch.empty* deterministic by filling with NaN or max int #104995

Closed
12 changes: 11 additions & 1 deletion aten/src/ATen/mps/EmptyTensor.cpp
```diff
@@ -7,6 +7,7 @@
 #include <torch/library.h>
 #include <ATen/mps/MPSDevice.h>
 #include <ATen/native/Resize.h>
+#include <ATen/native/TensorFactories.h>
 #include <ATen/native/mps/Copy.h>
 
 #define MPS_ERROR_NOT_COMPILED "PyTorch code is not compiled with MPS enabled"
@@ -63,6 +64,10 @@ TensorBase empty_mps(
 
   auto memory_format = memory_format_opt.value_or(MemoryFormat::Contiguous);
   tensor.unsafeGetTensorImpl()->empty_tensor_restride(memory_format);
+  // See Note [Enabling Deterministic Operations]
+  if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms())) {
+    at::native::fill_empty_deterministic_(tensor);
+  }
   return tensor;
 } else {
   TORCH_CHECK(false, MPS_ERROR_RUNTIME_TOO_LOW)
@@ -100,8 +105,13 @@ TensorBase empty_strided_mps(
   const DeviceGuard device_guard(device);
   auto* allocator = at::mps::GetMPSAllocator();
   constexpr c10::DispatchKeySet mps_dks(c10::DispatchKey::MPS);
-  return at::detail::empty_strided_generic(
+  Tensor result = at::detail::empty_strided_generic(
       size, stride, allocator, mps_dks, dtype);
+  // See Note [Enabling Deterministic Operations]
+  if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms())) {
+    at::native::fill_empty_deterministic_(result);
+  }
+  return result;
 } else {
   TORCH_CHECK(false, MPS_ERROR_RUNTIME_TOO_LOW)
 }
```
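A quick way to exercise both MPS code paths touched here, from Python (a minimal sketch; it assumes a build with MPS available, which is why it is gated on `torch.backends.mps.is_available()`):

```python
import torch

# Sketch: both the empty_mps and empty_strided_mps paths should now fill
# float tensors with NaN when deterministic algorithms are enabled.
if torch.backends.mps.is_available():
    torch.use_deterministic_algorithms(True)
    t = torch.empty(4, 4, device="mps")                    # empty_mps path
    s = torch.empty_strided((4, 4), (1, 4), device="mps")  # empty_strided_mps path
    print(t.isnan().all().item(), s.isnan().all().item())  # expected: True True
    torch.use_deterministic_algorithms(False)
```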
18 changes: 16 additions & 2 deletions aten/src/ATen/native/TensorFactories.cpp
```diff
@@ -253,7 +253,12 @@ Tensor polar(const Tensor& abs, const Tensor& angle) {
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ empty ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 Tensor empty_cpu(IntArrayRef size, c10::optional<ScalarType> dtype_opt, c10::optional<Layout> layout_opt,
                  c10::optional<Device> device_opt, c10::optional<bool> pin_memory_opt, c10::optional<c10::MemoryFormat> memory_format_opt) {
-  return at::detail::empty_cpu(size, dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt);
+  Tensor result = at::detail::empty_cpu(size, dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt);
+  // See Note [Enabling Deterministic Operations]
+  if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms())) {
+    fill_empty_deterministic_(result);
+  }
+  return result;
 }
 
 Tensor empty_names(
@@ -320,7 +325,12 @@ Tensor empty_permuted_symint(SymIntArrayRef size, IntArrayRef physical_layout, c
 
 Tensor empty_strided_cpu(IntArrayRef size, IntArrayRef stride, c10::optional<ScalarType> dtype_opt,
                          c10::optional<Layout> layout_opt, c10::optional<Device> device_opt, c10::optional<bool> pin_memory_opt) {
-  return at::detail::empty_strided_cpu(size, stride, dtype_opt, layout_opt, device_opt, pin_memory_opt);
+  Tensor result = at::detail::empty_strided_cpu(size, stride, dtype_opt, layout_opt, device_opt, pin_memory_opt);
+  // See Note [Enabling Deterministic Operations]
+  if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms())) {
+    fill_empty_deterministic_(result);
+  }
+  return result;
 }
 
 Tensor& empty_out(IntArrayRef size,
@@ -337,6 +347,10 @@ Tensor& empty_out(IntArrayRef size,
   } else {
     result.resize_(size);
  }
+  // See Note [Enabling Deterministic Operations]
+  if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms())) {
+    fill_empty_deterministic_(result);
+  }
   return result;
 }
```
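The user-visible effect of the CPU changes, as a minimal sketch (note that the `out=` call resizes the output tensor, so PyTorch emits its usual resize warning):

```python
import torch

torch.use_deterministic_algorithms(True)

f = torch.empty(3, 3)                     # empty_cpu: float tensors filled with NaN
i = torch.empty(3, 3, dtype=torch.int32)  # integer tensors filled with the dtype max
print(f.isnan().all().item())                            # True
print((i == torch.iinfo(torch.int32).max).all().item())  # True

# empty_out: the out= overload fills after resizing as well
out = torch.zeros(1)
torch.empty(3, 3, out=out)  # resizes out to (3, 3), then fills with NaN
print(out.isnan().all().item())                          # True

torch.use_deterministic_algorithms(False)
```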
15 changes: 13 additions & 2 deletions aten/src/ATen/native/cuda/TensorFactories.cu
```diff
@@ -9,6 +9,7 @@
 #include <ATen/native/TensorFactories.h>
 #include <c10/util/accumulate.h>
 #include <c10/util/Exception.h>
+#include <ATen/native/cuda/Loops.cuh>
 
 #ifndef AT_PER_OPERATOR_HEADERS
 #include <ATen/Functions.h>
@@ -51,7 +52,12 @@ Tensor& eye_out_cuda(int64_t n, int64_t m, Tensor& result) {
 }
 
 Tensor empty_cuda(IntArrayRef size, c10::optional<ScalarType> dtype_opt, c10::optional<Layout> layout_opt, c10::optional<Device> device_opt, c10::optional<bool> pin_memory_opt, c10::optional<c10::MemoryFormat> memory_format_opt) {
-  return at::detail::empty_cuda(size, dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt);
+  Tensor result = at::detail::empty_cuda(size, dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt);
+  // See Note [Enabling Deterministic Operations]
+  if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms())) {
+    fill_empty_deterministic_(result);
+  }
+  return result;
 }
 
 Tensor _efficientzerotensor_cuda(IntArrayRef size,
@@ -72,7 +78,12 @@ Tensor _efficientzerotensor_cuda(IntArrayRef size,
 
 
 Tensor empty_strided_cuda(IntArrayRef size, IntArrayRef stride, c10::optional<ScalarType> dtype_opt, c10::optional<Layout> layout_opt, c10::optional<Device> device_opt, c10::optional<bool> pin_memory_opt) {
-  return at::detail::empty_strided_cuda(size, stride, dtype_opt, layout_opt, device_opt, pin_memory_opt);
+  Tensor result = at::detail::empty_strided_cuda(size, stride, dtype_opt, layout_opt, device_opt, pin_memory_opt);
+  // See Note [Enabling Deterministic Operations]
+  if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms())) {
+    fill_empty_deterministic_(result);
+  }
+  return result;
 }
 
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ triangle ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
```
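The CUDA changes mirror the CPU ones; a sketch of how they surface in Python (gated on availability, since this only runs on a CUDA build):

```python
import torch

if torch.cuda.is_available():
    torch.use_deterministic_algorithms(True)
    a = torch.empty_like(torch.zeros(5, 4, device="cuda"))  # empty_cuda path
    b = torch.empty_strided((5, 4), (1, 5), device="cuda")  # empty_strided_cuda path
    print(a.isnan().all().item(), b.isnan().all().item())   # expected: True True
    torch.use_deterministic_algorithms(False)
```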
28 changes: 28 additions & 0 deletions test/test_torch.py
```diff
@@ -1342,6 +1342,34 @@ def test_deterministic_resize(self, device, dtype):
         else:
             self.assertEqual(old_tensor, new_tensor)
 
+    # When deterministic algorithms are enabled, `torch.empty` should fill floating
+    # point tensors with NaN and integer tensors with MAX_INT
+    @skipXLA
+    @skipIfTorchInductor("aot-autograd issue")
+    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
+    def test_deterministic_empty(self, device, dtype):
+        gen_fns = [
+            lambda: torch.empty(10, 9, device=device, dtype=dtype),
+            lambda: torch.empty(10, 9, out=torch.zeros(1, device=device, dtype=dtype)),
+            lambda: torch.empty_like(torch.zeros(10, 9, device=device, dtype=dtype)),
+            lambda: torch.empty_like(torch.zeros(10, 9, device=device, dtype=dtype), memory_format=torch.contiguous_format),
+            lambda: torch.empty_strided((10, 9), (1, 5), device=device, dtype=dtype),
+            lambda: torch.empty_permuted((2, 3, 5), (1, 0, 2), device=device, dtype=dtype),
+        ]
+
+        for gen_fn in gen_fns:
+            with DeterministicGuard(True):
+                res = gen_fn()
+
+            if dtype.is_floating_point or dtype.is_complex:
+                self.assertTrue(res.isnan().all())
+            else:
+                if dtype == torch.bool:
+                    max_val = True
+                else:
+                    max_val = torch.iinfo(dtype).max
+                self.assertTrue(res.eq(max_val).all())
+
     # FIXME: update OpInfos to support "nondeterministic samples" and port these tests
     # to that architecture
     @skipIfMps
```
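The bool branch in the test exists because `torch.iinfo` rejects `torch.bool`. A standalone restatement of the expected-value logic outside the test harness (`expected_fill` is a helper name introduced here for illustration):

```python
import torch

def expected_fill(dtype):
    # NaN for floating point and complex dtypes
    if dtype.is_floating_point or dtype.is_complex:
        return float("nan")
    # torch.iinfo(torch.bool) raises, so the "max" for bool is simply True
    if dtype == torch.bool:
        return True
    # maximum representable value for integer dtypes
    return torch.iinfo(dtype).max

print(expected_fill(torch.float16))  # nan
print(expected_fill(torch.bool))     # True
print(expected_fill(torch.int8))     # 127
```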
4 changes: 4 additions & 0 deletions torch/__init__.py
```diff
@@ -704,6 +704,10 @@ def use_deterministic_algorithms(mode, *, warn_only=False):
           quantized, sets new elements to a known value. Floating point or
           complex values are set to NaN. Integer values are set to the maximum
           value.
+        * :func:`torch.empty`, :func:`torch.empty_like`, :func:`torch.empty_strided`,
+          and :func:`torch.empty_permuted` will fill the output tensor with a known
+          value. Floating point or complex dtype tensors are filled with NaN. Integer
+          dtype tensors are filled with the maximum value.
 
     The following normally-nondeterministic operations will throw a
     :class:`RuntimeError` when ``mode=True``:
```
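A usage sketch for scoping the global flag around a region, using only documented API (`torch.are_deterministic_algorithms_enabled` reads the current mode):

```python
import torch

# Save and restore the global flag around a deterministic region.
prev = torch.are_deterministic_algorithms_enabled()
torch.use_deterministic_algorithms(True)
try:
    x = torch.empty(2, 2)   # guaranteed NaN-filled while the flag is set
    assert x.isnan().all()
finally:
    torch.use_deterministic_algorithms(prev)
```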
28 changes: 28 additions & 0 deletions torch/_torch_docs.py
```diff
@@ -12277,6 +12277,13 @@ def merge_dicts(*dicts):
 Returns a tensor filled with uninitialized data. The shape of the tensor is
 defined by the variable argument :attr:`size`.
 
+.. note::
+    If :func:`torch.use_deterministic_algorithms()` is set to ``True``, the
+    output tensor is initialized to prevent any possible nondeterministic
+    behavior from using the data as an input to an operation. Floating point
+    and complex tensors are filled with NaN, and integer tensors are filled
+    with the maximum value.
+
 Args:
     size (int...): a sequence of integers defining the shape of the output tensor.
         Can be a variable number of arguments or a collection like a list or tuple.
@@ -12309,6 +12316,13 @@ def merge_dicts(*dicts):
 ``torch.empty_like(input)`` is equivalent to
 ``torch.empty(input.size(), dtype=input.dtype, layout=input.layout, device=input.device)``.
 
+.. note::
+    If :func:`torch.use_deterministic_algorithms()` is set to ``True``, the
+    output tensor is initialized to prevent any possible nondeterministic
+    behavior from using the data as an input to an operation. Floating point
+    and complex tensors are filled with NaN, and integer tensors are filled
+    with the maximum value.
+
 Args:
     {input}
 
@@ -12341,6 +12355,13 @@ def merge_dicts(*dicts):
 If the constructed tensor is "overlapped" (with multiple indices referring to the same element
 in memory) its behavior is undefined.
 
+.. note::
+    If :func:`torch.use_deterministic_algorithms()` is set to ``True``, the
+    output tensor is initialized to prevent any possible nondeterministic
+    behavior from using the data as an input to an operation. Floating point
+    and complex tensors are filled with NaN, and integer tensors are filled
+    with the maximum value.
+
 Args:
     size (tuple of int): the shape of the output tensor
     stride (tuple of int): the strides of the output tensor
@@ -12386,6 +12407,13 @@ def merge_dicts(*dicts):
 tensor with no overlaps. If possible, prefer using this function over
 :func:`torch.empty_strided` or manual use of :func:`torch.as_strided`.
 
+.. note::
+    If :func:`torch.use_deterministic_algorithms()` is set to ``True``, the
+    output tensor is initialized to prevent any possible nondeterministic
+    behavior from using the data as an input to an operation. Floating point
+    and complex tensors are filled with NaN, and integer tensors are filled
+    with the maximum value.
+
 Args:
     size (tuple of int): the shape of the output tensor
     physical_layout (tuple of int): the ordering of dimensions physically in memory
```
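The same note now covers all four factory functions; a sketch of the least familiar one, showing that `torch.empty_permuted` keeps its logical shape and physical layout while being NaN-filled under deterministic mode (the stride shown assumes the documented layout semantics):

```python
import torch

torch.use_deterministic_algorithms(True)

# Logical shape (2, 3, 5); physical_layout (1, 0, 2) lays out dim 1 outermost.
t = torch.empty_permuted((2, 3, 5), (1, 0, 2))
print(t.shape)                 # torch.Size([2, 3, 5])
print(t.stride())              # (5, 10, 1)
print(t.isnan().all().item())  # True

torch.use_deterministic_algorithms(False)
```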