Skip to content

Commit

Permalink
Use TORCH_SYM_CHECK for check_size_nonnegative on SymIntArrayRef
Browse files Browse the repository at this point in the history
See #106788 for context.

I think I don't actually need this for anything real, but this is pretty
mild so might as well.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

[ghstack-poisoned]
  • Loading branch information
ezyang committed Aug 23, 2023
1 parent 207b06d commit 812d6c1
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions aten/src/ATen/EmptyTensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
namespace at {
namespace detail {

template <class ArrayRefType>
inline void check_size_nonnegative(ArrayRefType size) {
inline void check_size_nonnegative(ArrayRef<int64_t> size) {
for (const auto& x : size) {
TORCH_CHECK(
x >= 0,
Expand All @@ -16,6 +15,17 @@ inline void check_size_nonnegative(ArrayRefType size) {
}
}

inline void check_size_nonnegative(ArrayRef<c10::SymInt> size) {
for (const auto& x : size) {
TORCH_SYM_CHECK(
x.sym_ge(0),
"Trying to create tensor with negative dimension ",
x,
": ",
size);
}
}

TORCH_API size_t computeStorageNbytesContiguous(
IntArrayRef sizes,
size_t itemsize,
Expand Down

0 comments on commit 812d6c1

Please sign in to comment.