Skip to content

Commit

Permalink
Use expect_true to make split with unbacked sizes work.
Browse files Browse the repository at this point in the history
This pattern shows up in torchrec KeyedJaggedTensor.  Most
of the changes in this PR are mechanical: wherever an unbacked
SymInt test failed purely because of error checking, the
conditional is replaced with something that calls expect_true
(e.g., torch._check or TORCH_SYM_CHECK).

Some of the changes are a bit more nuanced; I've commented on the PR
accordingly.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

ghstack-source-id: c93177d89e1a1d6c15f2b0765c77c9f6282ec891
Pull Request resolved: #106788
  • Loading branch information
ezyang committed Aug 8, 2023
1 parent 49af66c commit 30136d3
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 26 deletions.
20 changes: 14 additions & 6 deletions aten/src/ATen/native/TensorShape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1349,8 +1349,12 @@ Tensor narrow(const Tensor& self, int64_t dim, int64_t start, int64_t length) {
TORCH_CHECK(self.dim() > 0, "narrow() cannot be applied to a 0-dim tensor.");
TORCH_CHECK(length >= 0, "narrow(): length must be non-negative.");
auto cur_size = self.size(dim);
if (start != cur_size) { // start being the end is valid, but not a valid dim specification.
start = maybe_wrap_dim(start, cur_size);
TORCH_CHECK_INDEX(
-cur_size <= start && start <= cur_size,
"start out of range (expected to be in range of [", -cur_size, ", ", cur_size, "], but got ", start, ")"
)
if (start < 0) {
start = start + cur_size;
}
TORCH_CHECK(start <= cur_size - length,
"start (", start, ") + length (", length, ") exceeds dimension size (", cur_size, ").");
Expand All @@ -1359,12 +1363,16 @@ Tensor narrow(const Tensor& self, int64_t dim, int64_t start, int64_t length) {

Tensor narrow_symint(const Tensor& self, int64_t dim, SymInt start, SymInt length) {
TORCH_CHECK(self.dim() > 0, "narrow() cannot be applied to a 0-dim tensor.");
TORCH_CHECK(length >= 0, "narrow(): length must be non-negative.");
TORCH_SYM_CHECK(length.sym_ge(0), "narrow(): length must be non-negative.");
auto cur_size = self.sym_size(dim);
if (start != cur_size) { // start being the end is valid, but not a valid dim specification.
start = maybe_wrap_dim(start, cur_size);
TORCH_CHECK_INDEX(
((-cur_size).sym_le(start).sym_and(start.sym_le(cur_size))).expect_true(__FILE__, __LINE__),
"start out of range (expected to be in range of [", -cur_size, ", ", cur_size, "], but got ", start, ")"
)
if (start < 0) {
start = start + cur_size;
}
TORCH_CHECK(start <= cur_size - length,
TORCH_SYM_CHECK(start.sym_le(cur_size - length),
"start (", start, ") + length (", length, ") exceeds dimension size (", cur_size, ").");
return at::slice_symint(self, dim, start, start + length, 1);
}
Expand Down
26 changes: 26 additions & 0 deletions test/test_proxy_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1174,6 +1174,32 @@ def forward(self, a_1):
empty = torch.ops.aten.empty.memory_format([add], device = device(type='cpu'), pin_memory = False); add = None
return empty""")

# Checks that torch.split can be traced in symbolic mode when the split sizes
# are values read out of a tensor with .item() rather than plain Python ints.
def test_split_unbacked_sizes(self):
def f(lengths, values):
# tolist not directly supported atm
# so each size is pulled out element by element via .item()
sizes = [lengths[i].item() for i in range(lengths.size(0))]
for s in sizes:
constrain_range(s, min=2) # TODO: better as constrain_as_size
return torch.split(values, sizes)

# Trace f with make_fx in symbolic tracing mode on concrete sample inputs
# (sizes 2 + 3 + 4 match the 9-element values tensor), then compare the
# generated graph code, stripped of surrounding whitespace, to the text below.
r = str(make_fx(f, tracing_mode="symbolic")(
torch.tensor([2, 3, 4]),
torch.randn(9)
).code).strip()
self.assertExpectedInline(r, """\
def forward(self, lengths_1, values_1):
select = torch.ops.aten.select.int(lengths_1, 0, 0)
_local_scalar_dense = torch.ops.aten._local_scalar_dense.default(select); select = None
select_1 = torch.ops.aten.select.int(lengths_1, 0, 1)
_local_scalar_dense_1 = torch.ops.aten._local_scalar_dense.default(select_1); select_1 = None
select_2 = torch.ops.aten.select.int(lengths_1, 0, 2); lengths_1 = None
_local_scalar_dense_2 = torch.ops.aten._local_scalar_dense.default(select_2); select_2 = None
split_with_sizes = torch.ops.aten.split_with_sizes.default(values_1, [_local_scalar_dense, _local_scalar_dense_1, _local_scalar_dense_2]); values_1 = _local_scalar_dense = _local_scalar_dense_1 = _local_scalar_dense_2 = None
getitem = split_with_sizes[0]
getitem_1 = split_with_sizes[1]
getitem_2 = split_with_sizes[2]; split_with_sizes = None
return (getitem, getitem_1, getitem_2)""") # noqa: B950

def test_invalidate_nonzero(self):
ok = False

Expand Down
2 changes: 1 addition & 1 deletion torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -948,7 +948,7 @@ def _check_with(error_type, cond: Union[builtins.bool, SymBool], message: Callab
if not isinstance(cond, (builtins.bool, torch.SymBool)):
raise TypeError(f'cond must be a bool, but got {type(cond)}')

if cond:
if torch.fx.experimental.symbolic_shapes.expect_true(cond):
return

# error_type must be a subclass of Exception and not subclass of Warning
Expand Down
22 changes: 15 additions & 7 deletions torch/_decomp/decompositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
_safe_copy_out,
out_wrapper,
)
from torch.fx.experimental.symbolic_shapes import guard_int
from torch.fx.experimental.symbolic_shapes import expect_true, guard_int
from torch.utils._pytree import tree_flatten, tree_map

DispatchKey = torch._C.DispatchKey # type: ignore[attr-defined]
Expand Down Expand Up @@ -690,12 +690,12 @@ def slice_forward(

if start_val < 0:
start_val = 0
elif start_val >= sizes[dim]:
elif start_val > sizes[dim]:
start_val = sizes[dim]

if end_val < start_val:
end_val = start_val
elif end_val >= sizes[dim]:
elif end_val > sizes[dim]:
end_val = sizes[dim]

storage_offset = self.storage_offset() + start_val * strides[dim]
Expand Down Expand Up @@ -1145,15 +1145,23 @@ def prod(x: List[int]):
def split_with_sizes(
self: Tensor, split_sizes: List[int], dim: int = 0
) -> List[Tensor]:
if sum(split_sizes) != self.shape[dim]:
raise ValueError(
"Split sizes don't add up to the tensor's size in the given dimension"
)
torch._check_with(
ValueError,
sum(split_sizes) == self.shape[dim],
lambda: "Split sizes don't add up to the tensor's size in the given dimension",
)
num_splits = len(split_sizes)
splits = []
start_idx = 0
for i in range(num_splits):
length = split_sizes[i]
torch._check(
length >= 0,
lambda: "split_with_sizes expects split_sizes have only non-negative entries",
)
# We know this is true thanks to the sum, but this assertion helps
# out our internal reasoning
expect_true(start_idx + length <= self.shape[dim])
splits.append(self.narrow(dim, start_idx, length))
start_idx += length
return splits
Expand Down
18 changes: 8 additions & 10 deletions torch/csrc/jit/python/init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1202,27 +1202,25 @@ void initJITBindings(PyObject* module) {
SYMNODE_SIZES_STRIDES(is_channels_last_strides_2d)
SYMNODE_SIZES_STRIDES(is_channels_last_strides_3d)
SYMNODE_SIZES_STRIDES(is_non_overlapping_and_dense)
// Intentionally don't set file line, as the
// Python backtrace matters more here
.def(
"guard_int",
[](c10::SymNode a) {
return a->guard_int(nullptr, 0);
[](c10::SymNode a, const char* file, int64_t line) {
return a->guard_int(file, line);
})
.def(
"guard_bool",
[](c10::SymNode a) {
return a->guard_bool(nullptr, 0);
[](c10::SymNode a, const char* file, int64_t line) {
return a->guard_bool(file, line);
})
.def(
"guard_float",
[](c10::SymNode a) {
return a->guard_float(nullptr, 0);
[](c10::SymNode a, const char* file, int64_t line) {
return a->guard_float(file, line);
})
.def(
"expect_true",
[](c10::SymNode a) {
return a->expect_true(nullptr, 0);
[](c10::SymNode a, const char* file, int64_t line) {
return a->expect_true(file, line);
})
.def(
"has_hint",
Expand Down
4 changes: 2 additions & 2 deletions torch/testing/_internal/common_methods_invocations.py
Original file line number Diff line number Diff line change
Expand Up @@ -5016,11 +5016,11 @@ def error_inputs_narrow_narrow_copy(op_info, device, *, is_narrow, is_ref):
else:
yield ErrorInput(SampleInput(make_arg((L, M, S)), 1, M + 1, 0),
error_type=IndexError,
error_regex=r"Dimension out of range \(expected to be in range of \[-10, 9\], but got 11\)")
error_regex=r"start out of range \(expected to be in range of \[-10, 10\], but got 11\)")
# out of bounds start (negative)
yield ErrorInput(SampleInput(make_arg((L, M, S)), 1, -M - 1, 0),
error_type=IndexError,
error_regex=r"Dimension out of range \(expected to be in range of \[-10, 9\], but got -11\)")
error_regex=r"start out of range \(expected to be in range of \[-10, 10\], but got -11\)")

# out of bounds length
yield ErrorInput(SampleInput(make_arg((S, L, M)), 2, 0, M + 1),
Expand Down

0 comments on commit 30136d3

Please sign in to comment.