Skip to content

Commit

Permalink
[WIP] [jit] allow slicing multiple dimensions with indices
Browse files Browse the repository at this point in the history
Currently investigating issue where indexing with an empty list
in eager returns a tensor with a zero-sized dim in the empty
list spot *OTOH* the scripted module will throw this error:
" RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
RuntimeError: Input must be of ints, floats, or bools, got Tensor
Empty lists default to List[Tensor]. Add a variable annotation to the
assignment to create an empty list of another type
(torch.jit.annotate(List[T, []]) where T is the type of elements in the
list for Python 2) "

ghstack-source-id: 10c0d1a8d164e9dfe8b1ec6a008782d98388a309
Pull Request resolved: #45239
  • Loading branch information
Lilyjjo committed Oct 5, 2020
1 parent f65ab89 commit 7f54254
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 3 deletions.
43 changes: 43 additions & 0 deletions test/jit/test_list_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,49 @@ def test_over_slice():
return a[3:10] == [3, 4]
self.checkScript(test_backward_slice, ())

def test_slice_index(self):
a = torch.tensor(
[
[[1, 11], [2, 22]],
[[3, 33], [4, 44]],
[[5, 55], [6, 66]],
]
)

def test_index_slice1(x):
x = x[:, :, [0, 1]]
return x

def test_index_slice2(x):
x = x[[2, 1, 0], :, :]
return x

def test_index_slice3(x):
x = x[[0, 1], :, [1]]
return x

def test_index_slice4(x):
x = x[[4], :, :]
return x

def test_index_slice5(x):
x = x[[], :, :]
return x

self.checkScript(test_index_slice1, (a,))
self.checkScript(test_index_slice2, (a,))
self.checkScript(test_index_slice3, (a,))

with self.assertRaisesRegex(RuntimeError, "index 4 is out of bounds for dimension 0 with size 3"):
self.checkScript(test_index_slice4, (a,))

with self.assertRaises(RuntimeError):
# using indexing with empty list resolves to tensor incorrectly,
# users need to add type annotations currently as a workaround
scripted = torch.jit.script(test_index_slice5)
scripted(a)


def test_mutable_list_append(self):
def test_append():
a = [0, 1]
Expand Down
5 changes: 2 additions & 3 deletions torch/jit/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,11 +686,10 @@ def build_SliceExpr(ctx, base, slice_expr):
return SliceExpr(base.range(), lower, upper, step)

def build_Index(ctx, base, index_expr):
if isinstance(index_expr.value, ast.Tuple) or \
isinstance(index_expr.value, ast.List):
if isinstance(index_expr.value, ast.Tuple):
raise NotSupportedError(base.range(),
"slicing multiple dimensions with "
"sequences not supported yet")
"tuples not supported yet")
return build_expr(ctx, index_expr.value)

def build_ExtSlice(ctx, base, extslice):
Expand Down

0 comments on commit 7f54254

Please sign in to comment.