Use expect_true to make split with unbacked sizes work.
This pattern shows up in torchrec KeyedJaggedTensor.  Most
of the change in this PR is mechanical: wherever we failed
an unbacked symint test due to mere error checking, the
conditional is replaced with something that calls expect_true
(e.g., torch._check or TORCH_SYM_CHECK).

Some of the changes are a bit more nuanced; I've commented on the PR
accordingly.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

ghstack-source-id: b1ba0de493b3141f54a844f683cdf94620d721ab
Pull Request resolved: #106788
ezyang committed Aug 15, 2023
1 parent a55f820 commit d183e6b
Showing 8 changed files with 99 additions and 45 deletions.
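
To make the mechanical transform concrete before the diffs below, here is a minimal sketch (illustrative, not code from this PR; validate is a hypothetical helper name) built around the same check the split_with_sizes decomposition adopts:

import torch

# Before: a data-dependent branch like this fails at trace time when
# sum(split_sizes) involves unbacked SymInts:
#
#     if sum(split_sizes) != dim_size:
#         raise ValueError("Split sizes don't add up ...")
#
# After: route the condition through torch._check_with, which raises
# eagerly for plain bools and records a deferred runtime assert for
# unbacked SymBools.
def validate(split_sizes, dim_size):
    torch._check_with(
        ValueError,
        sum(split_sizes) == dim_size,
        lambda: "Split sizes don't add up to the tensor's size in the given dimension",
    )

validate([2, 3, 4], 9)   # ok
# validate([2, 3], 9)    # raises ValueError eagerly
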
29 changes: 20 additions & 9 deletions aten/src/ATen/native/TensorShape.cpp
@@ -1308,9 +1308,12 @@ Tensor& narrow_copy_dense_cpu_out(
 
   // wrap start and do bound check
   const auto cur_size = self_sizes[dim];
-  if (start != cur_size && start < 0) { // start being the end is valid, but
-                                        // not a valid dim specification.
-    start = at::maybe_wrap_dim(start, cur_size);
+  TORCH_CHECK_INDEX(
+      -cur_size <= start && start <= cur_size,
+      "start out of range (expected to be in range of [", -cur_size, ", ", cur_size, "], but got ", start, ")"
+  )
+  if (start < 0) {
+    start = start + cur_size;
   }
   TORCH_CHECK(
       length >= 0 && start <= cur_size - length,
@@ -1374,8 +1377,12 @@ Tensor narrow(const Tensor& self, int64_t dim, int64_t start, int64_t length) {
   TORCH_CHECK(self.dim() > 0, "narrow() cannot be applied to a 0-dim tensor.");
   TORCH_CHECK(length >= 0, "narrow(): length must be non-negative.");
   auto cur_size = self.size(dim);
-  if (start != cur_size) { // start being the end is valid, but not a valid dim specification.
-    start = maybe_wrap_dim(start, cur_size);
+  TORCH_CHECK_INDEX(
+      -cur_size <= start && start <= cur_size,
+      "start out of range (expected to be in range of [", -cur_size, ", ", cur_size, "], but got ", start, ")"
+  )
+  if (start < 0) {
+    start = start + cur_size;
   }
   TORCH_CHECK(start <= cur_size - length,
       "start (", start, ") + length (", length, ") exceeds dimension size (", cur_size, ").");
@@ -1384,12 +1391,16 @@ Tensor narrow(const Tensor& self, int64_t dim, int64_t start, int64_t length) {
 
 Tensor narrow_symint(const Tensor& self, int64_t dim, SymInt start, SymInt length) {
   TORCH_CHECK(self.dim() > 0, "narrow() cannot be applied to a 0-dim tensor.");
-  TORCH_CHECK(length >= 0, "narrow(): length must be non-negative.");
+  TORCH_SYM_CHECK(length.sym_ge(0), "narrow(): length must be non-negative.");
   auto cur_size = self.sym_size(dim);
-  if (start != cur_size) { // start being the end is valid, but not a valid dim specification.
-    start = maybe_wrap_dim(start, cur_size);
+  TORCH_CHECK_INDEX(
+      ((-cur_size).sym_le(start).sym_and(start.sym_le(cur_size))).expect_true(__FILE__, __LINE__),
+      "start out of range (expected to be in range of [", -cur_size, ", ", cur_size, "], but got ", start, ")"
+  )
+  if (start < 0) {
+    start = start + cur_size;
   }
-  TORCH_CHECK(start <= cur_size - length,
+  TORCH_SYM_CHECK(start.sym_le(cur_size - length),
       "start (", start, ") + length (", length, ") exceeds dimension size (", cur_size, ").");
   return at::slice_symint(self, dim, start, start + length, 1);
 }
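
For context on the narrow changes above: the new explicit bounds check preserves narrow's edge case where start equal to the dimension size is accepted (with length 0), and negative starts still wrap. A small eager-mode sketch using only public API:

import torch

x = torch.arange(5)
print(torch.narrow(x, 0, 5, 0))   # start == size is valid when length is 0 -> empty tensor
print(torch.narrow(x, 0, -2, 2))  # negative start wraps to 3 -> tensor([3, 4])
# torch.narrow(x, 0, 6, 0)        # now reported as "start out of range" (IndexError)
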
15 changes: 15 additions & 0 deletions test/dynamo/test_misc.py
@@ -6296,6 +6296,21 @@ def mapper(x):
         self.assertEqual(counter.frame_count, 1)
         self.assertEqual(counter.op_count, 9)
 
+    @torch._dynamo.config.patch(
+        capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True
+    )
+    def test_unbacked_symint(self):
+        from torch._export.constraints import constrain_as_size
+
+        @torch.compile(backend="eager")
+        def f(lengths, values):
+            sizes = lengths.tolist()
+            for s in sizes:
+                constrain_as_size(s, min=2, max=100)
+            return torch.split(values, sizes)
+
+        f(torch.tensor([2, 3, 4]), torch.randn(9))
+
     def test_simple_set_usage(self):
         def foo(x, y):
             setty = {x, y}
29 changes: 29 additions & 0 deletions test/test_proxy_tensor.py
@@ -1176,6 +1176,35 @@ def forward(self, a_1):
     empty = torch.ops.aten.empty.memory_format([add], device = device(type='cpu'), pin_memory = False); add = None
     return empty""")
 
+    def test_split_unbacked_sizes(self):
+        def f(lengths, values):
+            # tolist not directly supported atm
+            sizes = [lengths[i].item() for i in range(lengths.size(0))]
+            for s in sizes:
+                constrain_as_size(s)
+            return torch.split(values, sizes)
+
+        r = str(make_fx(f, tracing_mode="symbolic")(
+            torch.tensor([2, 3, 4]),
+            torch.randn(9)
+        ).code).strip()
+        self.assertExpectedInline(r, """\
+def forward(self, lengths_1, values_1):
+    select = torch.ops.aten.select.int(lengths_1, 0, 0)
+    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(select); select = None
+    select_1 = torch.ops.aten.select.int(lengths_1, 0, 1)
+    _local_scalar_dense_1 = torch.ops.aten._local_scalar_dense.default(select_1); select_1 = None
+    select_2 = torch.ops.aten.select.int(lengths_1, 0, 2); lengths_1 = None
+    _local_scalar_dense_2 = torch.ops.aten._local_scalar_dense.default(select_2); select_2 = None
+    sym_constrain_range_for_size = torch.ops.aten.sym_constrain_range_for_size.default(_local_scalar_dense, min = None, max = None)
+    sym_constrain_range_for_size_1 = torch.ops.aten.sym_constrain_range_for_size.default(_local_scalar_dense_1, min = None, max = None)
+    sym_constrain_range_for_size_2 = torch.ops.aten.sym_constrain_range_for_size.default(_local_scalar_dense_2, min = None, max = None)
+    split_with_sizes = torch.ops.aten.split_with_sizes.default(values_1, [_local_scalar_dense, _local_scalar_dense_1, _local_scalar_dense_2]); values_1 = _local_scalar_dense = _local_scalar_dense_1 = _local_scalar_dense_2 = None
+    getitem = split_with_sizes[0]
+    getitem_1 = split_with_sizes[1]
+    getitem_2 = split_with_sizes[2]; split_with_sizes = None
+    return (getitem, getitem_1, getitem_2)""")  # noqa: B950
+
     def test_invalidate_nonzero(self):
         ok = False
 
2 changes: 1 addition & 1 deletion torch/__init__.py
@@ -948,7 +948,7 @@ def _check_with(error_type, cond: Union[builtins.bool, SymBool], message: Callable[[], str]):
     if not isinstance(cond, (builtins.bool, torch.SymBool)):
         raise TypeError(f'cond must be a bool, but got {type(cond)}')
 
-    if cond:
+    if torch.fx.experimental.symbolic_shapes.expect_true(cond):
        return
 
     # error_type must be a subclass of Exception and not subclass of Warning
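
On the eager path this change is behavior-preserving: for a plain bool, expect_true simply returns it unchanged, so torch._check still raises immediately on failure. A quick sketch:

import torch

torch._check(1 + 1 == 2, lambda: "never raised")  # cond is a plain bool: no-op
try:
    torch._check(False, lambda: "this always fails")
except RuntimeError as e:
    print(e)  # this always fails

# Only when cond is an unbacked SymBool during tracing does expect_true
# defer the check to a runtime assert instead of raising.
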
22 changes: 15 additions & 7 deletions torch/_decomp/decompositions.py
@@ -19,7 +19,7 @@
     _safe_copy_out,
     out_wrapper,
 )
-from torch.fx.experimental.symbolic_shapes import guard_int
+from torch.fx.experimental.symbolic_shapes import expect_true, guard_int
 from torch.utils._pytree import tree_flatten, tree_map
 
 DispatchKey = torch._C.DispatchKey  # type: ignore[attr-defined]
@@ -717,12 +717,12 @@ def slice_forward(
 
     if start_val < 0:
         start_val = 0
-    elif start_val >= sizes[dim]:
+    elif start_val > sizes[dim]:
         start_val = sizes[dim]
 
     if end_val < start_val:
         end_val = start_val
-    elif end_val >= sizes[dim]:
+    elif end_val > sizes[dim]:
         end_val = sizes[dim]
 
     storage_offset = self.storage_offset() + start_val * strides[dim]
@@ -1172,15 +1172,23 @@ def prod(x: List[int]):
 def split_with_sizes(
     self: Tensor, split_sizes: List[int], dim: int = 0
 ) -> List[Tensor]:
-    if sum(split_sizes) != self.shape[dim]:
-        raise ValueError(
-            "Split sizes don't add up to the tensor's size in the given dimension"
-        )
+    torch._check_with(
+        ValueError,
+        sum(split_sizes) == self.shape[dim],
+        lambda: "Split sizes don't add up to the tensor's size in the given dimension",
+    )
     num_splits = len(split_sizes)
     splits = []
     start_idx = 0
     for i in range(num_splits):
         length = split_sizes[i]
+        torch._check(
+            length >= 0,
+            lambda: "split_with_sizes expects split_sizes have only non-negative entries",
+        )
+        # We know this is true thanks to the sum, but this assertion helps
+        # out our internal reasoning
+        expect_true(start_idx + length <= self.shape[dim])
         splits.append(self.narrow(dim, start_idx, length))
         start_idx += length
     return splits
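
The interesting line here is the expect_true on start_idx + length: with concrete ints it is just an eagerly-evaluated bool, while with unbacked SymInts it records a deferred runtime assert that teaches the symbolic reasoner the bound without branching. A hedged sketch of that reasoning in isolation (checked_offsets is an illustrative helper, not a torch API):

from torch.fx.experimental.symbolic_shapes import expect_true

def checked_offsets(split_sizes, dim_size):
    # Mirrors the loop in the decomposition above: assert each slice stays
    # in bounds. For plain ints expect_true returns the bool itself; for
    # unbacked SymBools it returns True and records a runtime assert.
    offsets, start = [], 0
    for length in split_sizes:
        assert expect_true(start + length <= dim_size)
        offsets.append(start)
        start += length
    return offsets

print(checked_offsets([2, 3, 4], 9))  # [0, 2, 5]
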
15 changes: 7 additions & 8 deletions torch/_refs/__init__.py
@@ -2923,14 +2923,13 @@ def narrow(
     torch._check(length >= 0, lambda: "narrow(): length must be non-negative.")
     dim = utils.canonicalize_dim(a.ndim, dim)
     dim_length = a.size(dim)
-    # Start being the end is usually invalid since it's out of bounds. So it's
-    # not allowed by canonicalize_dim. But for narrow it's valid as long as
-    # the length is 0, which is handled by the check below.
-    if start != dim_length:
-        # Negative start means indexing from the end of dim.
-        # Note: a dimension isn't being canonicalized here, this reuses
-        # canonicalize_dim because the semantics are similar.
-        start = utils.canonicalize_dim(dim_length, start)  # type: ignore[arg-type]
+    torch._check_with(
+        IndexError,
+        -dim_length <= start and start <= dim_length,  # type: ignore[arg-type]
+        lambda: f"start out of range (expected to be in range of [{-dim_length}, {dim_length}], but got {start})",
+    )
+    if start < 0:
+        start = start + dim_length
     torch._check(
         start <= dim_length - length,  # type: ignore[arg-type]
         lambda: f"start ({start}) + length ({length}) exceeds dimension size ({dim_length}).",
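
The ref now reports out-of-range starts with the same IndexError message as the ATen implementation above; a quick hedged check of the eager behavior (matching the updated error inputs below, where M = 10):

import torch

x = torch.randn(3, 10, 5)
try:
    torch.narrow(x, 1, 11, 0)  # 11 is past the new [-10, 10] range
except IndexError as e:
    print(e)  # start out of range (expected to be in range of [-10, 10], but got 11)
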
18 changes: 8 additions & 10 deletions torch/csrc/jit/python/init.cpp
@@ -1210,27 +1210,25 @@ void initJITBindings(PyObject* module) {
       SYMNODE_SIZES_STRIDES(is_channels_last_strides_2d)
       SYMNODE_SIZES_STRIDES(is_channels_last_strides_3d)
       SYMNODE_SIZES_STRIDES(is_non_overlapping_and_dense)
-      // Intentionally don't set file line, as the
-      // Python backtrace matters more here
       .def(
           "guard_int",
-          [](c10::SymNode a) {
-            return a->guard_int(nullptr, 0);
+          [](c10::SymNode a, const char* file, int64_t line) {
+            return a->guard_int(file, line);
           })
       .def(
           "guard_bool",
-          [](c10::SymNode a) {
-            return a->guard_bool(nullptr, 0);
+          [](c10::SymNode a, const char* file, int64_t line) {
+            return a->guard_bool(file, line);
           })
       .def(
           "guard_float",
-          [](c10::SymNode a) {
-            return a->guard_float(nullptr, 0);
+          [](c10::SymNode a, const char* file, int64_t line) {
+            return a->guard_float(file, line);
           })
       .def(
           "expect_true",
-          [](c10::SymNode a) {
-            return a->expect_true(nullptr, 0);
+          [](c10::SymNode a, const char* file, int64_t line) {
+            return a->expect_true(file, line);
           })
       .def(
           "has_hint",
14 changes: 4 additions & 10 deletions torch/testing/_internal/common_methods_invocations.py
@@ -5013,19 +5013,13 @@ def error_inputs_narrow_narrow_copy(op_info, device, *, is_narrow, is_ref):
                      error_regex=r"Dimension out of range \(expected to be in range of \[-3, 2\], but got -4\)")
 
     # out of bounds start
-    if not is_narrow and not is_ref and torch.device(device).type == 'cpu':
-        # narrow_copy_dense_cpu_out
-        yield ErrorInput(SampleInput(make_arg((L, M, S)), 1, M + 1, 0),
-                         error_type=RuntimeError,
-                         error_regex=r"start \(11\) \+ length \(0\) exceeds dimension size \(10\)\.")
-    else:
-        yield ErrorInput(SampleInput(make_arg((L, M, S)), 1, M + 1, 0),
-                         error_type=IndexError,
-                         error_regex=r"Dimension out of range \(expected to be in range of \[-10, 9\], but got 11\)")
+    yield ErrorInput(SampleInput(make_arg((L, M, S)), 1, M + 1, 0),
+                     error_type=IndexError,
+                     error_regex=r"start out of range \(expected to be in range of \[-10, 10\], but got 11\)")
     # out of bounds start (negative)
     yield ErrorInput(SampleInput(make_arg((L, M, S)), 1, -M - 1, 0),
                      error_type=IndexError,
-                     error_regex=r"Dimension out of range \(expected to be in range of \[-10, 9\], but got -11\)")
+                     error_regex=r"start out of range \(expected to be in range of \[-10, 10\], but got -11\)")
 
     # out of bounds length
     yield ErrorInput(SampleInput(make_arg((S, L, M)), 2, 0, M + 1),
