Skip to content

Commit

Permalink
[jit] allow slicing multiple dimensions with indicies (#45239)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #45239

Test Plan: Imported from OSS

Reviewed By: gmagogsfm

Differential Revision: D23886919

Pulled By: Lilyjjo

fbshipit-source-id: d45c2a550fa8df9960cf2ab5da9d1ae0058a967a
  • Loading branch information
Lilyjjo authored and facebook-github-bot committed Oct 5, 2020
1 parent f11f9a8 commit 9a668f9
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 3 deletions.
37 changes: 37 additions & 0 deletions test/jit/test_list_dict.py
Expand Up @@ -408,6 +408,43 @@ def test_over_slice():
return a[3:10] == [3, 4]
self.checkScript(test_backward_slice, ())

def test_slice_index(self):
a = torch.tensor(
[
[[1, 11], [2, 22]],
[[3, 33], [4, 44]],
[[5, 55], [6, 66]],
]
)

def test_index_slice1(x):
x = x[:, :, [0, 1]]
return x
self.checkScript(test_index_slice1, (a,))

def test_index_slice2(x):
x = x[[2, 1, 0], :, :]
return x
self.checkScript(test_index_slice2, (a,))

def test_index_slice3(x):
x = x[[0, 1], :, [1]]
return x
self.checkScript(test_index_slice3, (a,))

def test_index_slice_empty_list(x):
empty_list: List[int] = []
x = x[empty_list, :, :]
return x
self.checkScript(test_index_slice_empty_list, (a,))

def test_index_slice_out_of_bounds_index(x):
x = x[[4], :, :]
return x
with self.assertRaisesRegex(RuntimeError, "index 4 is out of bounds for dimension 0 with size 3"):
self.checkScript(test_index_slice_out_of_bounds_index, (a,))


def test_mutable_list_append(self):
def test_append():
a = [0, 1]
Expand Down
5 changes: 2 additions & 3 deletions torch/jit/frontend.py
Expand Up @@ -686,11 +686,10 @@ def build_SliceExpr(ctx, base, slice_expr):
return SliceExpr(base.range(), lower, upper, step)

def build_Index(ctx, base, index_expr):
if isinstance(index_expr.value, ast.Tuple) or \
isinstance(index_expr.value, ast.List):
if isinstance(index_expr.value, ast.Tuple):
raise NotSupportedError(base.range(),
"slicing multiple dimensions with "
"sequences not supported yet")
"tuples not supported yet")
return build_expr(ctx, index_expr.value)

def build_ExtSlice(ctx, base, extslice):
Expand Down

0 comments on commit 9a668f9

Please sign in to comment.