Use expect_true to make split with unbacked sizes work. #106788

Closed · wants to merge 8 commits
29 changes: 20 additions & 9 deletions aten/src/ATen/native/TensorShape.cpp
@@ -1308,9 +1308,12 @@ Tensor& narrow_copy_dense_cpu_out(

// wrap start and do bound check
const auto cur_size = self_sizes[dim];
-  if (start != cur_size && start < 0) { // start being the end is valid, but
-                                        // not a valid dim specification.
-    start = at::maybe_wrap_dim(start, cur_size);
+  TORCH_CHECK_INDEX(
+      -cur_size <= start && start <= cur_size,
+      "start out of range (expected to be in range of [", -cur_size, ", ", cur_size, "], but got ", start, ")"
+  )
+  if (start < 0) {
+    start = start + cur_size;
}
TORCH_CHECK(
length >= 0 && start <= cur_size - length,
@@ -1374,8 +1377,12 @@ Tensor narrow(const Tensor& self, int64_t dim, int64_t start, int64_t length) {
TORCH_CHECK(self.dim() > 0, "narrow() cannot be applied to a 0-dim tensor.");
TORCH_CHECK(length >= 0, "narrow(): length must be non-negative.");
auto cur_size = self.size(dim);
-  if (start != cur_size) { // start being the end is valid, but not a valid dim specification.
-    start = maybe_wrap_dim(start, cur_size);
+  TORCH_CHECK_INDEX(
+      -cur_size <= start && start <= cur_size,
+      "start out of range (expected to be in range of [", -cur_size, ", ", cur_size, "], but got ", start, ")"
+  )
+  if (start < 0) {
+    start = start + cur_size;
}
TORCH_CHECK(start <= cur_size - length,
"start (", start, ") + length (", length, ") exceeds dimension size (", cur_size, ").");
@@ -1384,12 +1391,16 @@ Tensor narrow_symint(const Tensor& self, int64_t dim, SymInt start, SymInt length) {

Tensor narrow_symint(const Tensor& self, int64_t dim, SymInt start, SymInt length) {
TORCH_CHECK(self.dim() > 0, "narrow() cannot be applied to a 0-dim tensor.");
-  TORCH_CHECK(length >= 0, "narrow(): length must be non-negative.");
+  TORCH_SYM_CHECK(length.sym_ge(0), "narrow(): length must be non-negative.");
auto cur_size = self.sym_size(dim);
-  if (start != cur_size) { // start being the end is valid, but not a valid dim specification.
-    start = maybe_wrap_dim(start, cur_size);
+  TORCH_CHECK_INDEX(
+      ((-cur_size).sym_le(start).sym_and(start.sym_le(cur_size))).expect_true(__FILE__, __LINE__),
+      "start out of range (expected to be in range of [", -cur_size, ", ", cur_size, "], but got ", start, ")"
+  )
+  if (start < 0) {
+    start = start + cur_size;
Contributor Author: The start != cur_size test cannot be sym'ified, since one branch is not an error condition. What I did here was inline maybe_wrap_dim, which conventionally tests -cur_size <= start < cur_size; so you can make it work without extra branching by just changing the condition to -cur_size <= start <= cur_size.

Contributor Author: I can split out the inlining into its own PR if people would prefer.

}
-  TORCH_CHECK(start <= cur_size - length,
+  TORCH_SYM_CHECK(start.sym_le(cur_size - length),
"start (", start, ") + length (", length, ") exceeds dimension size (", cur_size, ").");
return at::slice_symint(self, dim, start, start + length, 1);
}
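For intuition, the widened wrap rule is easy to see in plain Python. A minimal sketch (hypothetical helper name) mirroring the inlined C++ check above:

```python
def wrap_start(start: int, cur_size: int) -> int:
    # Inlined maybe_wrap_dim with the upper bound widened to include
    # cur_size: start == cur_size is legal for narrow (it denotes the end
    # of the dimension, usable only with length == 0), so the check is
    # -cur_size <= start <= cur_size rather than the conventional
    # -cur_size <= start < cur_size.
    if not (-cur_size <= start <= cur_size):
        raise IndexError(
            f"start out of range (expected to be in range of "
            f"[{-cur_size}, {cur_size}], but got {start})"
        )
    return start + cur_size if start < 0 else start

assert wrap_start(-3, 10) == 7   # negative start wraps from the end
assert wrap_start(10, 10) == 10  # end-of-dim start, valid with length 0
```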
15 changes: 15 additions & 0 deletions test/dynamo/test_misc.py
@@ -6296,6 +6296,21 @@ def mapper(x):
        self.assertEqual(counter.frame_count, 1)
        self.assertEqual(counter.op_count, 9)

    @torch._dynamo.config.patch(
        capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True
    )
    def test_unbacked_symint(self):
        from torch._export.constraints import constrain_as_size

        @torch.compile(backend="eager")
        def f(lengths, values):
            sizes = lengths.tolist()
            for s in sizes:
                constrain_as_size(s, min=2, max=100)
            return torch.split(values, sizes)

        f(torch.tensor([2, 3, 4]), torch.randn(9))

    def test_simple_set_usage(self):
        def foo(x, y):
            setty = {x, y}
26 changes: 26 additions & 0 deletions test/test_proxy_tensor.py
@@ -1174,6 +1174,32 @@ def forward(self, a_1):
    empty = torch.ops.aten.empty.memory_format([add], device = device(type='cpu'), pin_memory = False); add = None
    return empty""")

    def test_split_unbacked_sizes(self):
        def f(lengths, values):
            # tolist not directly supported atm
            sizes = [lengths[i].item() for i in range(lengths.size(0))]
            for s in sizes:
                constrain_range(s, min=2)  # TODO: better as constrain_as_size
            return torch.split(values, sizes)

        r = str(make_fx(f, tracing_mode="symbolic")(
            torch.tensor([2, 3, 4]),
            torch.randn(9)
        ).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, lengths_1, values_1):
    select = torch.ops.aten.select.int(lengths_1, 0, 0)
    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(select); select = None
    select_1 = torch.ops.aten.select.int(lengths_1, 0, 1)
    _local_scalar_dense_1 = torch.ops.aten._local_scalar_dense.default(select_1); select_1 = None
    select_2 = torch.ops.aten.select.int(lengths_1, 0, 2); lengths_1 = None
    _local_scalar_dense_2 = torch.ops.aten._local_scalar_dense.default(select_2); select_2 = None
    split_with_sizes = torch.ops.aten.split_with_sizes.default(values_1, [_local_scalar_dense, _local_scalar_dense_1, _local_scalar_dense_2]); values_1 = _local_scalar_dense = _local_scalar_dense_1 = _local_scalar_dense_2 = None
    getitem = split_with_sizes[0]
    getitem_1 = split_with_sizes[1]
    getitem_2 = split_with_sizes[2]; split_with_sizes = None
    return (getitem, getitem_1, getitem_2)""")  # noqa: B950

    def test_invalidate_nonzero(self):
        ok = False
2 changes: 1 addition & 1 deletion torch/__init__.py
@@ -948,7 +948,7 @@ def _check_with(error_type, cond: Union[builtins.bool, SymBool], message: Callab
if not isinstance(cond, (builtins.bool, torch.SymBool)):
raise TypeError(f'cond must be a bool, but got {type(cond)}')

-    if cond:
+    if torch.fx.experimental.symbolic_shapes.expect_true(cond):
Contributor Author: This turns on expect_true for all our decomps/meta functions, pretty nice. (I couldn't do the same trick in C++; TORCH_CHECK is a very delicate macro and it was hard to insert expect_true without breaking some sites, and there's also the problem that operator< and friends actually return bool, not SymBool.)

Collaborator (@lezcano, Aug 15, 2023): nb. You can do a horrible, horrible thing, and locally overwrite the meaning of < inside TORCH_CHECK via a macro so that you automagically go from using operator< to lt.

Collaborator: Also, this line is fantastic. Gotta love having just one entrypoint for a given thing.

return

# error_type must be a subclass of Exception and not subclass of Warning
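A minimal sketch of the control flow this change produces (simplified; the real _check_with also validates error_type and the message callable):

```python
from torch.fx.experimental.symbolic_shapes import expect_true

def check_sketch(cond, message):
    # For a plain bool, expect_true simply returns it. For a SymBool
    # involving unbacked symbols, it records a deferred runtime assert and
    # returns True, so tracing optimistically continues down the success
    # path instead of failing on a data-dependent guard.
    if expect_true(cond):
        return
    raise RuntimeError(message())
```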
22 changes: 15 additions & 7 deletions torch/_decomp/decompositions.py
@@ -19,7 +19,7 @@
_safe_copy_out,
out_wrapper,
)
-from torch.fx.experimental.symbolic_shapes import guard_int
+from torch.fx.experimental.symbolic_shapes import expect_true, guard_int
from torch.utils._pytree import tree_flatten, tree_map

DispatchKey = torch._C.DispatchKey # type: ignore[attr-defined]
@@ -727,12 +727,12 @@ def slice_forward(

if start_val < 0:
start_val = 0
-    elif start_val >= sizes[dim]:
+    elif start_val > sizes[dim]:
start_val = sizes[dim]

if end_val < start_val:
end_val = start_val
-    elif end_val >= sizes[dim]:
+    elif end_val > sizes[dim]:
Contributor Author: These two changes don't actually change semantics, but they're pretty important: frequently, we will know that end_val is <= sizes[dim], but we don't know if it == sizes[dim] or not. By branching only if it is truly out of bounds, we can statically determine which branch we go down. This is enough for split.

end_val = sizes[dim]

storage_offset = self.storage_offset() + start_val * strides[dim]
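Concretely, the payoff of the strict comparison can be sketched with a hypothetical clamp helper: when the reasoning system already knows end_val <= sizes[dim], the strict test resolves statically, while the non-strict one would hinge on an unknowable equality.

```python
def clamp_end(end_val: int, size: int) -> int:
    # Given end_val <= size, `end_val > size` is statically False, so no
    # data-dependent branch is needed. `end_val >= size` would depend on
    # whether end_val == size, which may be undecidable for an unbacked
    # symbol.
    if end_val > size:
        return size
    return end_val
```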
@@ -1182,15 +1182,23 @@ def prod(x: List[int]):
def split_with_sizes(
self: Tensor, split_sizes: List[int], dim: int = 0
) -> List[Tensor]:
-    if sum(split_sizes) != self.shape[dim]:
-        raise ValueError(
-            "Split sizes don't add up to the tensor's size in the given dimension"
-        )
+    torch._check_with(
+        ValueError,
+        sum(split_sizes) == self.shape[dim],
+        lambda: "Split sizes don't add up to the tensor's size in the given dimension",
+    )
num_splits = len(split_sizes)
splits = []
start_idx = 0
for i in range(num_splits):
length = split_sizes[i]
+        torch._check(
+            length >= 0,
+            lambda: "split_with_sizes expects split_sizes have only non-negative entries",
+        )
+        # We know this is true thanks to the sum, but this assertion helps
+        # out our internal reasoning
+        expect_true(start_idx + length <= self.shape[dim])
splits.append(self.narrow(dim, start_idx, length))
start_idx += length
return splits
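The invariant that expect_true records can be sanity-checked with plain integers; a minimal sketch of the loop's reasoning:

```python
def split_bounds_hold(split_sizes, total):
    # If the sizes are non-negative and sum to `total`, every prefix
    # satisfies start_idx + length <= total, which is exactly the fact
    # asserted above so narrow()'s bound check discharges statically.
    assert sum(split_sizes) == total
    start_idx = 0
    for length in split_sizes:
        assert length >= 0
        assert start_idx + length <= total
        start_idx += length

split_bounds_hold([2, 3, 4], 9)
```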
15 changes: 7 additions & 8 deletions torch/_refs/__init__.py
@@ -2923,14 +2923,13 @@ def narrow(
torch._check(length >= 0, lambda: "narrow(): length must be non-negative.")
dim = utils.canonicalize_dim(a.ndim, dim)
dim_length = a.size(dim)
-    # Start being the end is usually invalid since it's out of bounds. So it's
-    # not allowed by canonicalize_dim. But for narrow it's valid as long as
-    # the length is 0, which is handled by the check below.
-    if start != dim_length:
-        # Negative start means indexing from the end of dim.
-        # Note: a dimension isn't being canonicalized here, this reuses
-        # canonicalize_dim because the semantics are similar.
-        start = utils.canonicalize_dim(dim_length, start)  # type: ignore[arg-type]
+    torch._check_with(
+        IndexError,
+        -dim_length <= start and start <= dim_length,  # type: ignore[arg-type]
+        lambda: f"start out of range (expected to be in range of [{-dim_length}, {dim_length}], but got {start})",
+    )
+    if start < 0:
+        start = start + dim_length
torch._check(
start <= dim_length - length, # type: ignore[arg-type]
lambda: f"start ({start}) + length ({length}) exceeds dimension size ({dim_length}).",
18 changes: 8 additions & 10 deletions torch/csrc/jit/python/init.cpp
@@ -1210,27 +1210,25 @@
SYMNODE_SIZES_STRIDES(is_channels_last_strides_2d)
SYMNODE_SIZES_STRIDES(is_channels_last_strides_3d)
SYMNODE_SIZES_STRIDES(is_non_overlapping_and_dense)
-      // Intentionally don't set file line, as the
-      // Python backtrace matters more here
.def(
"guard_int",
-          [](c10::SymNode a) {
-            return a->guard_int(nullptr, 0);
+          [](c10::SymNode a, const char* file, int64_t line) {
+            return a->guard_int(file, line);
})
.def(
"guard_bool",
-          [](c10::SymNode a) {
-            return a->guard_bool(nullptr, 0);
+          [](c10::SymNode a, const char* file, int64_t line) {
+            return a->guard_bool(file, line);
})
.def(
"guard_float",
-          [](c10::SymNode a) {
-            return a->guard_float(nullptr, 0);
+          [](c10::SymNode a, const char* file, int64_t line) {
+            return a->guard_float(file, line);
})
.def(
"expect_true",
-          [](c10::SymNode a) {
-            return a->expect_true(nullptr, 0);
+          [](c10::SymNode a, const char* file, int64_t line) {
+            return a->expect_true(file, line);
})
.def(
"has_hint",
14 changes: 4 additions & 10 deletions torch/testing/_internal/common_methods_invocations.py
@@ -5013,19 +5013,13 @@ def error_inputs_narrow_narrow_copy(op_info, device, *, is_narrow, is_ref):
error_regex=r"Dimension out of range \(expected to be in range of \[-3, 2\], but got -4\)")

# out of bounds start
-    if not is_narrow and not is_ref and torch.device(device).type == 'cpu':
-        # narrow_copy_dense_cpu_out
-        yield ErrorInput(SampleInput(make_arg((L, M, S)), 1, M + 1, 0),
-                         error_type=RuntimeError,
-                         error_regex=r"start \(11\) \+ length \(0\) exceeds dimension size \(10\)\.")
-    else:
-        yield ErrorInput(SampleInput(make_arg((L, M, S)), 1, M + 1, 0),
-                         error_type=IndexError,
-                         error_regex=r"Dimension out of range \(expected to be in range of \[-10, 9\], but got 11\)")
+    yield ErrorInput(SampleInput(make_arg((L, M, S)), 1, M + 1, 0),
+                     error_type=IndexError,
+                     error_regex=r"start out of range \(expected to be in range of \[-10, 10\], but got 11\)")
# out of bounds start (negative)
yield ErrorInput(SampleInput(make_arg((L, M, S)), 1, -M - 1, 0),
error_type=IndexError,
error_regex=r"Dimension out of range \(expected to be in range of \[-10, 9\], but got -11\)")
error_regex=r"start out of range \(expected to be in range of \[-10, 10\], but got -11\)")

# out of bounds length
yield ErrorInput(SampleInput(make_arg((S, L, M)), 2, 0, M + 1),