Reland: Make torch.empty* deterministic by filling with NaN or max int (#104995)

Relands #101849 after #104302 reverted it.

torchrec PR pytorch/torchrec#1269 fixes the torchrec failure that caused #101849 to be reverted.

Part of #82004

Pull Request resolved: #104995
Approved by: https://github.com/albanD
kurtamohler authored and pytorchmergebot committed Jul 13, 2023
1 parent 42530c1 commit f987d11
Showing 6 changed files with 100 additions and 5 deletions.
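
For context, the user-visible behavior this commit adds, sketched as a minimal example (assumes a build that includes this change):

import torch

torch.use_deterministic_algorithms(True)

# Floating point (and complex) "empty" tensors now come back NaN-filled...
f = torch.empty(2, 3)
assert f.isnan().all()

# ...and integer "empty" tensors come back filled with the dtype's max value.
i = torch.empty(2, 3, dtype=torch.int64)
assert (i == torch.iinfo(torch.int64).max).all()

torch.use_deterministic_algorithms(False)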
12 changes: 11 additions & 1 deletion aten/src/ATen/mps/EmptyTensor.cpp
@@ -7,6 +7,7 @@
 #include <torch/library.h>
 #include <ATen/mps/MPSDevice.h>
 #include <ATen/native/Resize.h>
+#include <ATen/native/TensorFactories.h>
 #include <ATen/native/mps/Copy.h>

 #define MPS_ERROR_NOT_COMPILED "PyTorch code is not compiled with MPS enabled"
@@ -63,6 +64,10 @@ TensorBase empty_mps(

     auto memory_format = memory_format_opt.value_or(MemoryFormat::Contiguous);
     tensor.unsafeGetTensorImpl()->empty_tensor_restride(memory_format);
+    // See Note [Enabling Deterministic Operations]
+    if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms())) {
+      at::native::fill_empty_deterministic_(tensor);
+    }
     return tensor;
   } else {
     TORCH_CHECK(false, MPS_ERROR_RUNTIME_TOO_LOW)
@@ -100,8 +105,13 @@ TensorBase empty_strided_mps(
     const DeviceGuard device_guard(device);
     auto* allocator = at::mps::GetMPSAllocator();
     constexpr c10::DispatchKeySet mps_dks(c10::DispatchKey::MPS);
-    return at::detail::empty_strided_generic(
+    Tensor result = at::detail::empty_strided_generic(
         size, stride, allocator, mps_dks, dtype);
+    // See Note [Enabling Deterministic Operations]
+    if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms())) {
+      at::native::fill_empty_deterministic_(result);
+    }
+    return result;
   } else {
     TORCH_CHECK(false, MPS_ERROR_RUNTIME_TOO_LOW)
   }
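
The MPS hunks above show the pattern every backend follows in this PR: allocate as before, then fill when deterministic algorithms are enabled. The fill rule itself is implemented by at::native::fill_empty_deterministic_ (declared in ATen/native/TensorFactories.h, hence the new include); a rough Python model of that rule, for illustration only — not the actual C++ implementation:

import torch

def fill_empty_deterministic(t: torch.Tensor) -> torch.Tensor:
    # Illustrative stand-in for the C++ helper, matching the documented rule:
    # NaN for floating point/complex dtypes, maximum value for integer dtypes.
    if t.is_floating_point() or t.is_complex():
        return t.fill_(float("nan"))
    if t.dtype == torch.bool:
        # bool has no torch.iinfo; its maximum value is True
        return t.fill_(True)
    return t.fill_(torch.iinfo(t.dtype).max)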
18 changes: 16 additions & 2 deletions aten/src/ATen/native/TensorFactories.cpp
@@ -253,7 +253,12 @@ Tensor polar(const Tensor& abs, const Tensor& angle) {
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ empty ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 Tensor empty_cpu(IntArrayRef size, c10::optional<ScalarType> dtype_opt, c10::optional<Layout> layout_opt,
                  c10::optional<Device> device_opt, c10::optional<bool> pin_memory_opt, c10::optional<c10::MemoryFormat> memory_format_opt) {
-  return at::detail::empty_cpu(size, dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt);
+  Tensor result = at::detail::empty_cpu(size, dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt);
+  // See Note [Enabling Deterministic Operations]
+  if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms())) {
+    fill_empty_deterministic_(result);
+  }
+  return result;
 }

 Tensor empty_names(
@@ -320,7 +325,12 @@ Tensor empty_permuted_symint(SymIntArrayRef size, IntArrayRef physical_layout, c

 Tensor empty_strided_cpu(IntArrayRef size, IntArrayRef stride, c10::optional<ScalarType> dtype_opt,
                          c10::optional<Layout> layout_opt, c10::optional<Device> device_opt, c10::optional<bool> pin_memory_opt) {
-  return at::detail::empty_strided_cpu(size, stride, dtype_opt, layout_opt, device_opt, pin_memory_opt);
+  Tensor result = at::detail::empty_strided_cpu(size, stride, dtype_opt, layout_opt, device_opt, pin_memory_opt);
+  // See Note [Enabling Deterministic Operations]
+  if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms())) {
+    fill_empty_deterministic_(result);
+  }
+  return result;
 }

 Tensor& empty_out(IntArrayRef size,
@@ -337,6 +347,10 @@ Tensor& empty_out(IntArrayRef size,
   } else {
     result.resize_(size);
   }
+  // See Note [Enabling Deterministic Operations]
+  if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms())) {
+    fill_empty_deterministic_(result);
+  }
   return result;
 }

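
Note that the empty_out change covers torch.empty's out= overload as well: the provided buffer is resized if needed and then filled, as the new test below exercises. A small sketch of the resulting behavior (assumes a build with this change):

import torch

torch.use_deterministic_algorithms(True)
buf = torch.zeros(1)
out = torch.empty(10, 9, out=buf)  # buf is resized to (10, 9), then filled
assert out.isnan().all()
torch.use_deterministic_algorithms(False)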
15 changes: 13 additions & 2 deletions aten/src/ATen/native/cuda/TensorFactories.cu
@@ -9,6 +9,7 @@
 #include <ATen/native/TensorFactories.h>
 #include <c10/util/accumulate.h>
 #include <c10/util/Exception.h>
+#include <ATen/native/cuda/Loops.cuh>

 #ifndef AT_PER_OPERATOR_HEADERS
 #include <ATen/Functions.h>
@@ -51,7 +52,12 @@ Tensor& eye_out_cuda(int64_t n, int64_t m, Tensor& result) {
 }

 Tensor empty_cuda(IntArrayRef size, c10::optional<ScalarType> dtype_opt, c10::optional<Layout> layout_opt, c10::optional<Device> device_opt, c10::optional<bool> pin_memory_opt, c10::optional<c10::MemoryFormat> memory_format_opt) {
-  return at::detail::empty_cuda(size, dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt);
+  Tensor result = at::detail::empty_cuda(size, dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt);
+  // See Note [Enabling Deterministic Operations]
+  if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms())) {
+    fill_empty_deterministic_(result);
+  }
+  return result;
 }

 Tensor _efficientzerotensor_cuda(IntArrayRef size,
@@ -72,7 +78,12 @@ Tensor _efficientzerotensor_cuda(IntArrayRef size,


 Tensor empty_strided_cuda(IntArrayRef size, IntArrayRef stride, c10::optional<ScalarType> dtype_opt, c10::optional<Layout> layout_opt, c10::optional<Device> device_opt, c10::optional<bool> pin_memory_opt) {
-  return at::detail::empty_strided_cuda(size, stride, dtype_opt, layout_opt, device_opt, pin_memory_opt);
+  Tensor result = at::detail::empty_strided_cuda(size, stride, dtype_opt, layout_opt, device_opt, pin_memory_opt);
+  // See Note [Enabling Deterministic Operations]
+  if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms())) {
+    fill_empty_deterministic_(result);
+  }
+  return result;
 }

 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ triangle ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
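
With the CPU, CUDA, and MPS factories all patched, the strided variant behaves identically on each backend; for example (CPU shown; passing device='cuda' assumes a CUDA build and a visible GPU):

import torch

torch.use_deterministic_algorithms(True)
# Size (2, 3) with strides (3, 1) is a plain contiguous layout, so every
# element of the allocation is reachable and comes back NaN-filled.
s = torch.empty_strided((2, 3), (3, 1))
assert s.isnan().all()
torch.use_deterministic_algorithms(False)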
28 changes: 28 additions & 0 deletions test/test_torch.py
@@ -1342,6 +1342,34 @@ def test_deterministic_resize(self, device, dtype):
         else:
             self.assertEqual(old_tensor, new_tensor)

+    # When deterministic algorithms are enabled, `torch.empty` should fill floating
+    # point tensors with NaN and integer tensors with MAX_INT
+    @skipXLA
+    @skipIfTorchInductor("aot-autograd issue")
+    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
+    def test_deterministic_empty(self, device, dtype):
+        gen_fns = [
+            lambda: torch.empty(10, 9, device=device, dtype=dtype),
+            lambda: torch.empty(10, 9, out=torch.zeros(1, device=device, dtype=dtype)),
+            lambda: torch.empty_like(torch.zeros(10, 9, device=device, dtype=dtype)),
+            lambda: torch.empty_like(torch.zeros(10, 9, device=device, dtype=dtype), memory_format=torch.contiguous_format),
+            lambda: torch.empty_strided((10, 9), (1, 5), device=device, dtype=dtype),
+            lambda: torch.empty_permuted((2, 3, 5), (1, 0, 2), device=device, dtype=dtype),
+        ]
+
+        for gen_fn in gen_fns:
+            with DeterministicGuard(True):
+                res = gen_fn()
+
+            if dtype.is_floating_point or dtype.is_complex:
+                self.assertTrue(res.isnan().all())
+            else:
+                if dtype == torch.bool:
+                    max_val = True
+                else:
+                    max_val = torch.iinfo(dtype).max
+                self.assertTrue(res.eq(max_val).all())
+
     # FIXME: update OpInfos to support "nondeterministic samples" and port these tests
     # to that architecture
     @skipIfMps
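
The test scopes the flag with DeterministicGuard (a helper from torch.testing._internal.common_utils that restores the previous setting on exit). User code without access to test internals can do the same scoping by hand, sketched below:

import torch

prev = torch.are_deterministic_algorithms_enabled()
torch.use_deterministic_algorithms(True)
try:
    t = torch.empty(3, 3)  # NaN-filled while the flag is set
    assert t.isnan().all()
finally:
    torch.use_deterministic_algorithms(prev)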
4 changes: 4 additions & 0 deletions torch/__init__.py
@@ -704,6 +704,10 @@ def use_deterministic_algorithms(mode, *, warn_only=False):
           quantized, sets new elements to a known value. Floating point or
           complex values are set to NaN. Integer values are set to the maximum
           value.
+        * :func:`torch.empty`, :func:`torch.empty_like`, :func:`torch.empty_strided`,
+          and :func:`torch.empty_permuted` will fill the output tensor with a known
+          value. Floating point or complex dtype tensors are filled with NaN. Integer
+          dtype tensors are filled with the maximum value.

     The following normally-nondeterministic operations will throw a
     :class:`RuntimeError` when ``mode=True``:
28 changes: 28 additions & 0 deletions torch/_torch_docs.py
@@ -12277,6 +12277,13 @@ def merge_dicts(*dicts):
 Returns a tensor filled with uninitialized data. The shape of the tensor is
 defined by the variable argument :attr:`size`.

+.. note::
+    If :func:`torch.use_deterministic_algorithms()` is set to ``True``, the
+    output tensor is initialized to prevent any possible nondeterministic
+    behavior from using the data as an input to an operation. Floating point
+    and complex tensors are filled with NaN, and integer tensors are filled
+    with the maximum value.
+
 Args:
     size (int...): a sequence of integers defining the shape of the output tensor.
         Can be a variable number of arguments or a collection like a list or tuple.
@@ -12309,6 +12316,13 @@ def merge_dicts(*dicts):
 ``torch.empty_like(input)`` is equivalent to
 ``torch.empty(input.size(), dtype=input.dtype, layout=input.layout, device=input.device)``.

+.. note::
+    If :func:`torch.use_deterministic_algorithms()` is set to ``True``, the
+    output tensor is initialized to prevent any possible nondeterministic
+    behavior from using the data as an input to an operation. Floating point
+    and complex tensors are filled with NaN, and integer tensors are filled
+    with the maximum value.
+
 Args:
     {input}
@@ -12341,6 +12355,13 @@ def merge_dicts(*dicts):
 If the constructed tensor is "overlapped" (with multiple indices referring to the same element
 in memory) its behavior is undefined.

+.. note::
+    If :func:`torch.use_deterministic_algorithms()` is set to ``True``, the
+    output tensor is initialized to prevent any possible nondeterministic
+    behavior from using the data as an input to an operation. Floating point
+    and complex tensors are filled with NaN, and integer tensors are filled
+    with the maximum value.
+
 Args:
     size (tuple of int): the shape of the output tensor
     stride (tuple of int): the strides of the output tensor
@@ -12386,6 +12407,13 @@ def merge_dicts(*dicts):
 tensor with no overlaps. If possible, prefer using this function over
 :func:`torch.empty_strided` or manual use of :func:`torch.as_strided`.

+.. note::
+    If :func:`torch.use_deterministic_algorithms()` is set to ``True``, the
+    output tensor is initialized to prevent any possible nondeterministic
+    behavior from using the data as an input to an operation. Floating point
+    and complex tensors are filled with NaN, and integer tensors are filled
+    with the maximum value.
+
 Args:
     size (tuple of int): the shape of the output tensor
     physical_layout (tuple of int): the ordering of dimensions physically in memory
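
Since torch.empty_permuted is the one variant whose layout semantics may be unfamiliar, a short worked example of the docstring above (the fill behavior assumes a build with this change):

import torch

torch.use_deterministic_algorithms(True)
# physical_layout (1, 0, 2) puts dim 1 outermost and dim 2 innermost in
# memory, yielding a non-overlapping tensor with strides (5, 10, 1).
p = torch.empty_permuted((2, 3, 5), (1, 0, 2))
assert p.shape == (2, 3, 5) and p.stride() == (5, 10, 1)
assert p.isnan().all()  # filled because deterministic mode is on
torch.use_deterministic_algorithms(False)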
