Add overflow check for stride calculation (#94900)
Fixes #94120 and #94128.

Pull Request resolved: #94900
Approved by: https://github.com/ezyang, https://github.com/jgong5
CaoE authored and pytorchmergebot committed Apr 9, 2023
1 parent 3925f6e commit fdb04c6
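For context: the value 2305843009213693952 used in the new tests is 2**61, and a shape like [0, 4, 2**61] has numel() == 0, so the existing storage-size overflow check never fires; only the intermediate stride product overflows int64_t. Below is a minimal sketch of that arithmetic, separate from the PyTorch sources, assuming a GCC/Clang toolchain for __builtin_mul_overflow (the same builtin c10::mul_overflows wraps when available):

// Minimal sketch (not PyTorch code) of the arithmetic this commit guards:
// shape [0, 4, 2**61] allocates nothing (numel() == 0), yet the dim-0
// contiguous stride is 4 * 2**61 == 2**63, one past INT64_MAX.
#include <cstdint>
#include <cstdio>

int main() {
  int64_t stride0 = 0;
  bool overflowed =
      __builtin_mul_overflow(int64_t{4}, int64_t{1} << 61, &stride0);
  // Prints: overflowed=1 stride0=-9223372036854775808 (wrapped to INT64_MIN)
  std::printf("overflowed=%d stride0=%lld\n", overflowed, (long long)stride0);
  return 0;
}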
Showing 4 changed files with 29 additions and 7 deletions.
c10/core/TensorImpl.h: 12 additions, 7 deletions
@@ -1754,7 +1754,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
         new_stride.size(),
         ")");
     const auto new_dim = new_size.size();
-
+    bool overflowed = false;
     sizes_and_strides_.set_sizes(new_size);
 
     if (new_dim > 0) {
@@ -1769,15 +1769,17 @@
             sizes_and_strides_.stride_at_unchecked(dim) = 1;
           } else {
             // Keep stride monotonically increasing to match NumPy.
-            sizes_and_strides_.stride_at_unchecked(dim) =
+            overflowed |= c10::mul_overflows(
+                sizes_and_strides_.stride_at_unchecked(dim + 1),
                 std::max<int64_t>(
-                    sizes_and_strides_.size_at_unchecked(dim + 1), 1) *
-                sizes_and_strides_.stride_at_unchecked(dim + 1);
+                    sizes_and_strides_.size_at_unchecked(dim + 1), 1),
+                std::addressof(sizes_and_strides_.stride_at_unchecked(dim)));
           }
         }
         if (dim == 0)
           break;
       }
+      TORCH_CHECK(!overflowed, "Stride calculation overflowed");
     }
 
     refresh_numel();
@@ -2274,14 +2276,17 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
         const auto dim_ = dim();
         sizes_and_strides_.resize(dim_);
         if (dim_ > 0) {
+          bool overflowed = false;
           const auto last_idx = dim_ - 1;
           sizes_and_strides_.stride_at_unchecked(last_idx) = 1;
           for (auto i = last_idx - 1; i >= 0; --i) {
-            sizes_and_strides_.stride_at_unchecked(i) =
-                sizes_and_strides_.stride_at_unchecked(i + 1) *
+            overflowed |= c10::mul_overflows(
+                sizes_and_strides_.stride_at_unchecked(i + 1),
                 std::max<int64_t>(
-                    sizes_and_strides_.size_at_unchecked(i + 1), 1);
+                    sizes_and_strides_.size_at_unchecked(i + 1), 1),
+                std::addressof(sizes_and_strides_.stride_at_unchecked(i)));
           }
+          TORCH_CHECK(!overflowed, "Stride calculation overflowed");
         }
         break;
       }
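Both hunks above apply the same pattern: the raw right-to-left stride multiply becomes a c10::mul_overflows call, the per-dimension results are OR-ed into one overflowed flag, and a single TORCH_CHECK after the loop reports "Stride calculation overflowed" instead of branching and throwing inside the loop. A standalone rendering of that pattern, my own sketch with __builtin_mul_overflow and std::runtime_error standing in for the c10 helpers:

// Standalone sketch of the checked contiguous-stride loop above; the names
// and the GCC/Clang builtin are substitutes, not the c10 implementation.
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <stdexcept>
#include <vector>

std::vector<int64_t> contiguous_strides_checked(std::vector<int64_t> size) {
  std::vector<int64_t> stride(size.size());
  bool overflowed = false;
  if (!size.empty()) {
    stride.back() = 1;
    for (int64_t i = static_cast<int64_t>(size.size()) - 2; i >= 0; --i) {
      // stride[i] = stride[i + 1] * max(size[i + 1], 1), with overflow check.
      overflowed |= __builtin_mul_overflow(
          stride[i + 1], std::max<int64_t>(size[i + 1], 1), &stride[i]);
    }
  }
  if (overflowed) {
    throw std::runtime_error("Stride calculation overflowed");
  }
  return stride;
}

int main() {
  try {
    contiguous_strides_checked({0, 4, int64_t{1} << 61});
  } catch (const std::runtime_error& e) {
    std::puts(e.what()); // Stride calculation overflowed
  }
  return 0;
}

Deferring the failure to one TORCH_CHECK per call keeps the per-dimension cost at a flag OR, at the price of computing the remaining (garbage) strides before raising.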
c10/util/safe_numerics.h: 13 additions, 0 deletions
@@ -45,6 +45,19 @@ C10_ALWAYS_INLINE bool mul_overflows(uint64_t a, uint64_t b, uint64_t* out) {
 #endif
 }
 
+C10_ALWAYS_INLINE bool mul_overflows(int64_t a, int64_t b, int64_t* out) {
+#if C10_HAS_BUILTIN_OVERFLOW()
+  return __builtin_mul_overflow(a, b, out);
+#else
+  volatile int64_t tmp = a * b;
+  *out = tmp;
+  if (a == 0 || b == 0) {
+    return false;
+  }
+  return !(a == tmp / b);
+#endif
+}
+
 template <typename It>
 bool safe_multiplies_u64(It first, It last, uint64_t* out) {
 #if C10_HAS_BUILTIN_OVERFLOW()
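The new int64_t overload mirrors the existing uint64_t one: use the compiler builtin when C10_HAS_BUILTIN_OVERFLOW() is true, otherwise fall back to multiplying, storing the possibly wrapped product, and dividing back to detect the wrap. The volatile temporary looks intended to keep the compiler from optimizing the check away, since the fallback's signed multiply is technically undefined behavior that the builtin path avoids. A standalone copy of the fallback for experimentation, under a hypothetical name:

// Standalone copy of the portable fallback above (function name is mine).
// If a * b wrapped, tmp / b no longer recovers a; zero factors are excluded
// both because they cannot overflow and to avoid dividing by zero.
#include <cstdint>
#include <cstdio>

bool mul_overflows_i64_fallback(int64_t a, int64_t b, int64_t* out) {
  volatile int64_t tmp = a * b;
  *out = tmp;
  if (a == 0 || b == 0) {
    return false;
  }
  return !(a == tmp / b);
}

int main() {
  int64_t out = 0;
  // 4 * 2**61 == 2**63 wraps past INT64_MAX: detected.
  std::printf("%d\n", mul_overflows_i64_fallback(4, int64_t{1} << 61, &out));
  // 2**20 * 2**20 == 2**40 fits: not flagged.
  std::printf("%d\n", mul_overflows_i64_fallback(1 << 20, 1 << 20, &out));
  return 0;
}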
test/test_tensor_creation_ops.py: 2 additions, 0 deletions
@@ -2504,6 +2504,8 @@ def test_empty_overflow(self, device):
             torch.empty([8, 8, 2**29, 2**29], dtype=torch.float64)
         with self.assertRaisesRegex(RuntimeError, 'Storage size calculation overflowed'):
             torch.empty_strided([8, 8], [2**61, 1], dtype=torch.float64)
+        with self.assertRaisesRegex(RuntimeError, 'Stride calculation overflowed'):
+            torch.empty([0, 4, 2305843009213693952], dtype=torch.float32)
 
     def test_eye(self, device):
         for dtype in all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16):
test/test_view_ops.py: 2 additions, 0 deletions
@@ -1811,6 +1811,8 @@ def test_resize_overflow(self, device):
             x.resize_([2, 4, 2**29, 2**29])
         with self.assertRaisesRegex(RuntimeError, 'overflow'):
             x.resize_([8, 8, 2**29, 2**29])
+        with self.assertRaisesRegex(RuntimeError, 'Stride calculation overflowed'):
+            x.resize_([0, 4, 2305843009213693952])
 
     def test_view_all_dtypes_and_devices(self, device):
         for dt in all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool):
