Use expect_true to make split with unbacked sizes work. (#106788)
This pattern shows up in torchrec KeyedJaggedTensor.  Most of the
change in this PR is mechanical: wherever an unbacked symint test
failed purely because of error checking, the conditional is replaced
with something that calls expect_true (e.g., torch._check or
TORCH_SYM_CHECK).

Some of the changes are a bit more nuanced; I've commented on the PR
accordingly.
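
As a minimal end-to-end sketch of what this unlocks (adapted from the
new Dynamo test added in this PR; eager backend, toy sizes), splitting
a tensor by sizes read out of another tensor now compiles instead of
failing on a data-dependent guard:

    import torch
    from torch._export.constraints import constrain_as_size

    torch._dynamo.config.capture_scalar_outputs = True
    torch._dynamo.config.capture_dynamic_output_shape_ops = True

    @torch.compile(backend="eager")
    def f(lengths, values):
        # Each element of sizes is an unbacked SymInt under compilation.
        sizes = lengths.tolist()
        for s in sizes:
            # Declare each size valid and bounded; the remaining bound
            # checks inside split now defer via expect_true.
            constrain_as_size(s, min=2, max=100)
        return torch.split(values, sizes)

    f(torch.tensor([2, 3, 4]), torch.randn(9))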

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: #106788
Approved by: https://github.com/lezcano
ghstack dependencies: #106720
ezyang authored and pytorchmergebot committed Aug 15, 2023
1 parent e1ee10e commit 5673c08
Showing 8 changed files with 99 additions and 45 deletions.
29 changes: 20 additions & 9 deletions aten/src/ATen/native/TensorShape.cpp
@@ -1308,9 +1308,12 @@ Tensor& narrow_copy_dense_cpu_out(
 
   // wrap start and do bound check
   const auto cur_size = self_sizes[dim];
-  if (start != cur_size && start < 0) { // start being the end is valid, but
-                                        // not a valid dim specification.
-    start = at::maybe_wrap_dim(start, cur_size);
+  TORCH_CHECK_INDEX(
+      -cur_size <= start && start <= cur_size,
+      "start out of range (expected to be in range of [", -cur_size, ", ", cur_size, "], but got ", start, ")"
+  )
+  if (start < 0) {
+    start = start + cur_size;
   }
   TORCH_CHECK(
       length >= 0 && start <= cur_size - length,
@@ -1374,8 +1377,12 @@ Tensor narrow(const Tensor& self, int64_t dim, int64_t start, int64_t length) {
   TORCH_CHECK(self.dim() > 0, "narrow() cannot be applied to a 0-dim tensor.");
   TORCH_CHECK(length >= 0, "narrow(): length must be non-negative.");
   auto cur_size = self.size(dim);
-  if (start != cur_size) { // start being the end is valid, but not a valid dim specification.
-    start = maybe_wrap_dim(start, cur_size);
+  TORCH_CHECK_INDEX(
+      -cur_size <= start && start <= cur_size,
+      "start out of range (expected to be in range of [", -cur_size, ", ", cur_size, "], but got ", start, ")"
+  )
+  if (start < 0) {
+    start = start + cur_size;
   }
   TORCH_CHECK(start <= cur_size - length,
       "start (", start, ") + length (", length, ") exceeds dimension size (", cur_size, ").");
@@ -1384,12 +1391,16 @@ Tensor narrow(const Tensor& self, int64_t dim, int64_t start, int64_t length) {
 
 Tensor narrow_symint(const Tensor& self, int64_t dim, SymInt start, SymInt length) {
   TORCH_CHECK(self.dim() > 0, "narrow() cannot be applied to a 0-dim tensor.");
-  TORCH_CHECK(length >= 0, "narrow(): length must be non-negative.");
+  TORCH_SYM_CHECK(length.sym_ge(0), "narrow(): length must be non-negative.");
   auto cur_size = self.sym_size(dim);
-  if (start != cur_size) { // start being the end is valid, but not a valid dim specification.
-    start = maybe_wrap_dim(start, cur_size);
+  TORCH_CHECK_INDEX(
+      ((-cur_size).sym_le(start).sym_and(start.sym_le(cur_size))).expect_true(__FILE__, __LINE__),
+      "start out of range (expected to be in range of [", -cur_size, ", ", cur_size, "], but got ", start, ")"
+  )
+  if (start < 0) {
+    start = start + cur_size;
   }
-  TORCH_CHECK(start <= cur_size - length,
+  TORCH_SYM_CHECK(start.sym_le(cur_size - length),
       "start (", start, ") + length (", length, ") exceeds dimension size (", cur_size, ").");
   return at::slice_symint(self, dim, start, start + length, 1);
 }
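
All three hunks above follow the same recipe: the old wrap-and-check via
maybe_wrap_dim branched on the SymInt (and so forced a guard), while the
new code states the bound as a single checkable fact and then wraps
unconditionally. A rough Python rendering of the new control flow, with
a hypothetical helper name, purely for illustration:

    import torch

    def wrap_and_check_start(start, cur_size):
        # Hypothetical rendering of the new narrow() bound check: assert
        # the range (deferrable for unbacked SymInts via expect_true),
        # then wrap negative starts unconditionally.
        torch._check_with(
            IndexError,
            -cur_size <= start and start <= cur_size,
            lambda: f"start out of range (expected to be in range of "
                    f"[{-cur_size}, {cur_size}], but got {start})",
        )
        if start < 0:
            start = start + cur_size
        return start

Note that the start != cur_size special case disappears: start ==
cur_size is simply in range now, and the later start <= cur_size -
length check still rejects everything but length == 0 there.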
15 changes: 15 additions & 0 deletions test/dynamo/test_misc.py
@@ -6296,6 +6296,21 @@ def mapper(x):
         self.assertEqual(counter.frame_count, 1)
         self.assertEqual(counter.op_count, 9)
 
+    @torch._dynamo.config.patch(
+        capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True
+    )
+    def test_unbacked_symint(self):
+        from torch._export.constraints import constrain_as_size
+
+        @torch.compile(backend="eager")
+        def f(lengths, values):
+            sizes = lengths.tolist()
+            for s in sizes:
+                constrain_as_size(s, min=2, max=100)
+            return torch.split(values, sizes)
+
+        f(torch.tensor([2, 3, 4]), torch.randn(9))
+
     def test_simple_set_usage(self):
         def foo(x, y):
             setty = {x, y}
29 changes: 29 additions & 0 deletions test/test_proxy_tensor.py
@@ -1176,6 +1176,35 @@ def forward(self, a_1):
     empty = torch.ops.aten.empty.memory_format([add], device = device(type='cpu'), pin_memory = False); add = None
     return empty""")
 
+    def test_split_unbacked_sizes(self):
+        def f(lengths, values):
+            # tolist not directly supported atm
+            sizes = [lengths[i].item() for i in range(lengths.size(0))]
+            for s in sizes:
+                constrain_as_size(s)
+            return torch.split(values, sizes)
+
+        r = str(make_fx(f, tracing_mode="symbolic")(
+            torch.tensor([2, 3, 4]),
+            torch.randn(9)
+        ).code).strip()
+        self.assertExpectedInline(r, """\
+def forward(self, lengths_1, values_1):
+    select = torch.ops.aten.select.int(lengths_1, 0, 0)
+    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(select); select = None
+    select_1 = torch.ops.aten.select.int(lengths_1, 0, 1)
+    _local_scalar_dense_1 = torch.ops.aten._local_scalar_dense.default(select_1); select_1 = None
+    select_2 = torch.ops.aten.select.int(lengths_1, 0, 2); lengths_1 = None
+    _local_scalar_dense_2 = torch.ops.aten._local_scalar_dense.default(select_2); select_2 = None
+    sym_constrain_range_for_size = torch.ops.aten.sym_constrain_range_for_size.default(_local_scalar_dense, min = None, max = None)
+    sym_constrain_range_for_size_1 = torch.ops.aten.sym_constrain_range_for_size.default(_local_scalar_dense_1, min = None, max = None)
+    sym_constrain_range_for_size_2 = torch.ops.aten.sym_constrain_range_for_size.default(_local_scalar_dense_2, min = None, max = None)
+    split_with_sizes = torch.ops.aten.split_with_sizes.default(values_1, [_local_scalar_dense, _local_scalar_dense_1, _local_scalar_dense_2]); values_1 = _local_scalar_dense = _local_scalar_dense_1 = _local_scalar_dense_2 = None
+    getitem = split_with_sizes[0]
+    getitem_1 = split_with_sizes[1]
+    getitem_2 = split_with_sizes[2]; split_with_sizes = None
+    return (getitem, getitem_1, getitem_2)""")  # noqa: B950
+
     def test_invalidate_nonzero(self):
         ok = False
 
2 changes: 1 addition & 1 deletion torch/__init__.py
@@ -948,7 +948,7 @@ def _check_with(error_type, cond: Union[builtins.bool, SymBool], message: Callable
     if not isinstance(cond, (builtins.bool, torch.SymBool)):
         raise TypeError(f'cond must be a bool, but got {type(cond)}')
 
-    if cond:
+    if torch.fx.experimental.symbolic_shapes.expect_true(cond):
        return
 
     # error_type must be a subclass of Exception and not subclass of Warning
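
With this change, torch._check no longer branches on the condition
directly. expect_true passes a plain bool straight through, so eager
behavior is unchanged; for a SymBool over unbacked symbols it records a
deferred runtime assert and returns True, so tracing proceeds as if the
check passed. A quick eager sanity check of the unchanged plain-bool
path (a sketch, nothing symbolic involved):

    import torch

    torch._check(True, lambda: "never raised")   # returns normally
    try:
        torch._check(False, lambda: "still raises eagerly")
    except RuntimeError as err:
        print(err)  # plain-bool failures raise immediately, as before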
22 changes: 15 additions & 7 deletions torch/_decomp/decompositions.py
@@ -19,7 +19,7 @@
     _safe_copy_out,
     out_wrapper,
 )
-from torch.fx.experimental.symbolic_shapes import guard_int
+from torch.fx.experimental.symbolic_shapes import expect_true, guard_int
 from torch.utils._pytree import tree_flatten, tree_map
 
 DispatchKey = torch._C.DispatchKey  # type: ignore[attr-defined]
@@ -717,12 +717,12 @@ def slice_forward(
 
     if start_val < 0:
         start_val = 0
-    elif start_val >= sizes[dim]:
+    elif start_val > sizes[dim]:
         start_val = sizes[dim]
 
     if end_val < start_val:
         end_val = start_val
-    elif end_val >= sizes[dim]:
+    elif end_val > sizes[dim]:
         end_val = sizes[dim]
 
     storage_offset = self.storage_offset() + start_val * strides[dim]
@@ -1172,15 +1172,23 @@ def prod(x: List[int]):
 def split_with_sizes(
     self: Tensor, split_sizes: List[int], dim: int = 0
 ) -> List[Tensor]:
-    if sum(split_sizes) != self.shape[dim]:
-        raise ValueError(
-            "Split sizes don't add up to the tensor's size in the given dimension"
-        )
+    torch._check_with(
+        ValueError,
+        sum(split_sizes) == self.shape[dim],
+        lambda: "Split sizes don't add up to the tensor's size in the given dimension",
+    )
     num_splits = len(split_sizes)
     splits = []
     start_idx = 0
     for i in range(num_splits):
         length = split_sizes[i]
+        torch._check(
+            length >= 0,
+            lambda: "split_with_sizes expects split_sizes have only non-negative entries",
+        )
+        # We know this is true thanks to the sum, but this assertion helps
+        # out our internal reasoning
+        expect_true(start_idx + length <= self.shape[dim])
         splits.append(self.narrow(dim, start_idx, length))
         start_idx += length
     return splits
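
The per-iteration expect_true is redundant given the sum equality, but
it hands the symbolic layer a fact it cannot cheaply re-derive: every
prefix offset stays in bounds, so each narrow() call succeeds without a
fresh guard. An eager sanity check of that invariant with concrete
numbers:

    # Non-negative sizes summing to the dim size keep every prefix in
    # bounds; this is what expect_true records symbolically per step.
    sizes = [2, 3, 4]
    dim_size = sum(sizes)   # 9
    start = 0
    for length in sizes:
        assert length >= 0
        assert start + length <= dim_size
        start += length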
15 changes: 7 additions & 8 deletions torch/_refs/__init__.py
@@ -2923,14 +2923,13 @@ def narrow(
     torch._check(length >= 0, lambda: "narrow(): length must be non-negative.")
     dim = utils.canonicalize_dim(a.ndim, dim)
     dim_length = a.size(dim)
-    # Start being the end is usually invalid since it's out of bounds. So it's
-    # not allowed by canonicalize_dim. But for narrow it's valid as long as
-    # the length is 0, which is handled by the check below.
-    if start != dim_length:
-        # Negative start means indexing from the end of dim.
-        # Note: a dimension isn't being canonicalized here, this reuses
-        # canonicalize_dim because the semantics are similar.
-        start = utils.canonicalize_dim(dim_length, start)  # type: ignore[arg-type]
+    torch._check_with(
+        IndexError,
+        -dim_length <= start and start <= dim_length,  # type: ignore[arg-type]
+        lambda: f"start out of range (expected to be in range of [{-dim_length}, {dim_length}], but got {start})",
+    )
+    if start < 0:
+        start = start + dim_length
     torch._check(
         start <= dim_length - length,  # type: ignore[arg-type]
         lambda: f"start ({start}) + length ({length}) exceeds dimension size ({dim_length}).",
18 changes: 8 additions & 10 deletions torch/csrc/jit/python/init.cpp
@@ -1210,27 +1210,25 @@ void initJITBindings(PyObject* module) {
       SYMNODE_SIZES_STRIDES(is_channels_last_strides_2d)
       SYMNODE_SIZES_STRIDES(is_channels_last_strides_3d)
       SYMNODE_SIZES_STRIDES(is_non_overlapping_and_dense)
-      // Intentionally don't set file line, as the
-      // Python backtrace matters more here
       .def(
           "guard_int",
-          [](c10::SymNode a) {
-            return a->guard_int(nullptr, 0);
+          [](c10::SymNode a, const char* file, int64_t line) {
+            return a->guard_int(file, line);
           })
       .def(
           "guard_bool",
-          [](c10::SymNode a) {
-            return a->guard_bool(nullptr, 0);
+          [](c10::SymNode a, const char* file, int64_t line) {
+            return a->guard_bool(file, line);
           })
       .def(
           "guard_float",
-          [](c10::SymNode a) {
-            return a->guard_float(nullptr, 0);
+          [](c10::SymNode a, const char* file, int64_t line) {
+            return a->guard_float(file, line);
           })
       .def(
           "expect_true",
-          [](c10::SymNode a) {
-            return a->expect_true(nullptr, 0);
+          [](c10::SymNode a, const char* file, int64_t line) {
+            return a->expect_true(file, line);
           })
       .def(
           "has_hint",
14 changes: 4 additions & 10 deletions torch/testing/_internal/common_methods_invocations.py
@@ -5013,19 +5013,13 @@ def error_inputs_narrow_narrow_copy(op_info, device, *, is_narrow, is_ref):
                      error_regex=r"Dimension out of range \(expected to be in range of \[-3, 2\], but got -4\)")
 
     # out of bounds start
-    if not is_narrow and not is_ref and torch.device(device).type == 'cpu':
-        # narrow_copy_dense_cpu_out
-        yield ErrorInput(SampleInput(make_arg((L, M, S)), 1, M + 1, 0),
-                         error_type=RuntimeError,
-                         error_regex=r"start \(11\) \+ length \(0\) exceeds dimension size \(10\)\.")
-    else:
-        yield ErrorInput(SampleInput(make_arg((L, M, S)), 1, M + 1, 0),
-                         error_type=IndexError,
-                         error_regex=r"Dimension out of range \(expected to be in range of \[-10, 9\], but got 11\)")
+    yield ErrorInput(SampleInput(make_arg((L, M, S)), 1, M + 1, 0),
+                     error_type=IndexError,
+                     error_regex=r"start out of range \(expected to be in range of \[-10, 10\], but got 11\)")
     # out of bounds start (negative)
     yield ErrorInput(SampleInput(make_arg((L, M, S)), 1, -M - 1, 0),
                      error_type=IndexError,
-                     error_regex=r"Dimension out of range \(expected to be in range of \[-10, 9\], but got -11\)")
+                     error_regex=r"start out of range \(expected to be in range of \[-10, 10\], but got -11\)")
 
     # out of bounds length
     yield ErrorInput(SampleInput(make_arg((S, L, M)), 2, 0, M + 1),
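
Concretely, assuming the standard OpInfo sizes (L = 20, M = 10, S = 5)
and this PR applied, the updated error contract for narrow can be
reproduced in eager mode; a small sketch:

    import torch

    x = torch.randn(20, 10, 5)    # mirrors make_arg((L, M, S))
    try:
        torch.narrow(x, 1, 11, 0)   # start == M + 1 on a size-10 dim
    except IndexError as e:
        # start out of range (expected to be in range of [-10, 10], but got 11)
        print(e)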
