Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 14 additions & 5 deletions backends/arm/test/ops/test_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
import unittest
from typing import Tuple

import pytest

import torch

from executorch.backends.arm.test import common
from executorch.backends.arm.test import common, conftest
from executorch.backends.arm.test.tester.arm_tester import ArmTester
from executorch.exir.backend.compile_spec_schema import CompileSpec
from parameterized import parameterized
Expand All @@ -35,7 +37,7 @@ def forward(self, x: torch.Tensor):
def _test_slice_tosa_MI_pipeline(
self, module: torch.nn.Module, test_data: torch.Tensor
):
(
tester = (
ArmTester(
module,
example_inputs=test_data,
Expand All @@ -48,14 +50,16 @@ def _test_slice_tosa_MI_pipeline(
.partition()
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
.to_executorch()
.run_method_and_compare_outputs(inputs=test_data)
)

if conftest.is_option_enabled("tosa_ref_model"):
tester.run_method_and_compare_outputs(inputs=test_data)

def _test_slice_tosa_BI_pipeline(
self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
):

(
tester = (
ArmTester(
module,
example_inputs=test_data,
Expand All @@ -68,9 +72,11 @@ def _test_slice_tosa_BI_pipeline(
.partition()
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
.to_executorch()
.run_method_and_compare_outputs(inputs=test_data, qtol=1)
)

if conftest.is_option_enabled("tosa_ref_model"):
tester.run_method_and_compare_outputs(inputs=test_data, qtol=1)

def _test_slice_ethos_BI_pipeline(
self,
compile_spec: list[CompileSpec],
Expand Down Expand Up @@ -107,14 +113,17 @@ def _test_slice_u85_BI_pipeline(
)

@parameterized.expand(Slice.test_tensors)
@pytest.mark.tosa_ref_model
def test_slice_tosa_MI(self, tensor):
self._test_slice_tosa_MI_pipeline(self.Slice(), (tensor,))

@parameterized.expand(Slice.test_tensors[:2])
@pytest.mark.tosa_ref_model
def test_slice_nchw_tosa_BI(self, test_tensor: torch.Tensor):
self._test_slice_tosa_BI_pipeline(self.Slice(), (test_tensor,))

@parameterized.expand(Slice.test_tensors[2:])
@pytest.mark.tosa_ref_model
def test_slice_nhwc_tosa_BI(self, test_tensor: torch.Tensor):
self._test_slice_tosa_BI_pipeline(self.Slice(), (test_tensor,))

Expand Down
3 changes: 2 additions & 1 deletion backends/arm/test/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ def define_arm_tests():
test_files.remove("passes/test_ioquantization_pass.py")

# Operators
test_files += native.glob(["ops/test_linear.py"])
test_files += ["ops/test_linear.py"]
test_files += ["ops/test_slice.py"]

TESTS = {}

Expand Down
Loading