diff --git a/backends/arm/operators/op_slice.py b/backends/arm/operators/op_slice.py index a9e1a6cdbe2..27ae977a5bc 100644 --- a/backends/arm/operators/op_slice.py +++ b/backends/arm/operators/op_slice.py @@ -23,6 +23,18 @@ class SliceVisitor(NodeVisitor): def __init__(self, *args): super().__init__(*args) + def _fixup_start(self, start, shape, dim): + if start.number < 0: + return start.number % shape[dim] + else: + return start.number + + def _fixup_end(self, end, shape, dim): + if end.number < 0: + return end.number % shape[dim] + else: + return min(end.number, shape[dim]) + def define_node( self, node: Node, @@ -42,17 +54,21 @@ def define_node( # Translate and check parameters in Pytorch dim order. shape = input_node.shape dim = dim.number - if end.number < 0: - end_index = end.number % shape[dim] - else: - end_index = min(end.number, shape[dim]) - size = end_index - start.number + + start_index = self._fixup_start(start, shape, dim) + end_index = self._fixup_end(end, shape, dim) + size = end_index - start_index + assert size > 0 assert size <= shape[dim] # Convert aten args to Tosa's start and size attributes and in TOSA dim order. attr = ts.TosaSerializerAttribute() - start_attr = [start.number if i == dim else 0 for i in input_node.dim_order] + + start_attr = [ + self._fixup_start(start, shape, dim) if i == dim else 0 + for i in input_node.dim_order + ] size_attr = [size if i == dim else shape[i] for i in input_node.dim_order] attr.SliceAttribute(start_attr, size_attr) diff --git a/backends/arm/test/ops/test_slice.py b/backends/arm/test/ops/test_slice.py index 6900fda3abe..91ef51cc2a2 100644 --- a/backends/arm/test/ops/test_slice.py +++ b/backends/arm/test/ops/test_slice.py @@ -16,23 +16,21 @@ from executorch.exir.backend.compile_spec_schema import CompileSpec from parameterized import parameterized +test_data_suite = [ + (torch.ones(10), [(3, -3)]), + (torch.ones(10), [(-8, 3)]), + (torch.ones(10, 10), [(1, 3), (3, None)]), + (torch.ones(10, 10, 10), [(0, 7), (0, None), (0, 8)]), + (torch.ones((1, 12, 10, 10)), [(None, None), (None, 5), (3, 5), (4, 10)]), +] + class TestSimpleSlice(unittest.TestCase): class Slice(torch.nn.Module): - - sizes = [(10), (10, 10), (10, 10, 10), ((1, 12, 10, 10))] - test_tensors = [(torch.ones(n),) for n in sizes] - - def forward(self, x: torch.Tensor): - if x.dim() == 1: - return x[3:-3] - elif x.dim() == 2: - return x[1:3, 3:] - elif x.dim() == 3: - return x[0:7, 0:, 0:8] - elif x.dim() == 4: - return x[:, :5, 3:5, 4:10] + def forward(self, x: torch.Tensor, s: list[tuple[int, int]]): + slices = [slice(*i) for i in s] + return x[slices] def _test_slice_tosa_MI_pipeline( self, module: torch.nn.Module, test_data: torch.Tensor @@ -112,25 +110,29 @@ def _test_slice_u85_BI_pipeline( common.get_u85_compile_spec(), module, test_data ) - @parameterized.expand(Slice.test_tensors) + @parameterized.expand(test_data_suite) @pytest.mark.tosa_ref_model - def test_slice_tosa_MI(self, tensor): - self._test_slice_tosa_MI_pipeline(self.Slice(), (tensor,)) + def test_slice_tosa_MI(self, tensor: torch.Tensor, slices: list[tuple[int, int]]): + self._test_slice_tosa_MI_pipeline(self.Slice(), (tensor, slices)) - @parameterized.expand(Slice.test_tensors[:2]) + @parameterized.expand(test_data_suite) @pytest.mark.tosa_ref_model - def test_slice_nchw_tosa_BI(self, test_tensor: torch.Tensor): - self._test_slice_tosa_BI_pipeline(self.Slice(), (test_tensor,)) + def test_slice_nchw_tosa_BI( + self, tensor: torch.Tensor, slices: list[tuple[int, int]] + ): + self._test_slice_tosa_BI_pipeline(self.Slice(), (tensor, slices)) - @parameterized.expand(Slice.test_tensors[2:]) + @parameterized.expand(test_data_suite) @pytest.mark.tosa_ref_model - def test_slice_nhwc_tosa_BI(self, test_tensor: torch.Tensor): - self._test_slice_tosa_BI_pipeline(self.Slice(), (test_tensor,)) + def test_slice_nhwc_tosa_BI( + self, tensor: torch.Tensor, slices: list[tuple[int, int]] + ): + self._test_slice_tosa_BI_pipeline(self.Slice(), (tensor, slices)) - @parameterized.expand(Slice.test_tensors) - def test_slice_u55_BI(self, test_tensor: torch.Tensor): - self._test_slice_u55_BI_pipeline(self.Slice(), (test_tensor,)) + @parameterized.expand(test_data_suite) + def test_slice_u55_BI(self, tensor: torch.Tensor, slices: list[tuple[int, int]]): + self._test_slice_u55_BI_pipeline(self.Slice(), (tensor, slices)) - @parameterized.expand(Slice.test_tensors) - def test_slice_u85_BI(self, test_tensor: torch.Tensor): - self._test_slice_u85_BI_pipeline(self.Slice(), (test_tensor,)) + @parameterized.expand(test_data_suite) + def test_slice_u85_BI(self, tensor: torch.Tensor, slices: list[tuple[int, int]]): + self._test_slice_u85_BI_pipeline(self.Slice(), (tensor, slices))