diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index c26cd8fb078..aa0f391781b 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -197,6 +197,7 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule: self.add_pass(CastBoolToInt8Pass()) self.add_pass(DecomposeSinhPass()) self.add_pass(DecomposeSignPass()) + self.add_pass(DecomposeDivTensorModePass()) self.add_pass(ReplaceScalarWithTensorArgPassTOSAMI()) self.add_pass(DecomposeEmbeddingPass()) self.add_pass(FuseQuantizedActivationPass()) @@ -215,7 +216,6 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule: DecomposeMeanDimPass(exported_program.graph_module, self.tosa_spec) ) self.add_pass(DecomposeNotEqualPass()) - self.add_pass(DecomposeDivTensorModePass()) self.add_pass(DecomposeDivPass()) self.add_pass(DecomposeSoftmaxPass()) self.add_pass(DecomposeGeluPass()) @@ -285,6 +285,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule): self.add_pass(CastBoolToInt8Pass()) self.add_pass(DecomposeSignPass()) self.add_pass(DecomposeAddmmPass()) + self.add_pass(DecomposeDivTensorModePass()) self.add_pass(ReplaceScalarWithTensorArgPassTOSABI()) self.add_pass(ScalarsToAttributePass()) self.add_pass(DecomposeGroupNormPass()) @@ -294,7 +295,6 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule): self.add_pass(DecomposeNotEqualPass()) self.add_pass(DecomposeCosineSimilarityPass()) self.add_pass(DecomposeGluPass()) - self.add_pass(DecomposeDivTensorModePass()) self.add_pass(DecomposeDivPass()) self.add_pass(DecomposeLeakyReLUPass()) self.add_pass(DecomposeLinearVectorNormPass()) diff --git a/backends/arm/test/ops/test_div_tensor_mode.py b/backends/arm/test/ops/test_div_tensor_mode.py index f78aca85bcd..909b83bd97f 100644 --- a/backends/arm/test/ops/test_div_tensor_mode.py +++ b/backends/arm/test/ops/test_div_tensor_mode.py @@ -4,7 +4,6 @@ from typing import Tuple -import pytest import torch from executorch.backends.arm.test import common @@ -19,13 +18,6 @@ input_tt = Tuple[torch.Tensor, torch.Tensor] -def make_float_div_inputs(B: int = 4, T: int = 64) -> input_tt: - x = torch.randn(B, T) - # guard against zero in denominator - y = torch.randn(B, T).abs() + 1e-3 - return x, y - - class DivTensorModeFloat(torch.nn.Module): """ torch.div(x, y, rounding_mode=mode) with @@ -44,11 +36,24 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: return torch.div(x, y, rounding_mode=self.mode) -@pytest.mark.parametrize("mode", [None, "floor", "trunc"]) -def test_div_tensor_mode_tosa_FP(mode): +test_data = { + "mode_none": lambda: (None, (torch.randn(4, 8), torch.randn(4, 8).abs() + 1e-3)), + "mode_floor": lambda: ( + "floor", + (torch.randn(4, 8), torch.randn(4, 8).abs() + 1e-3), + ), + "mode_trunc": lambda: ( + "trunc", + (torch.randn(4, 8), torch.randn(4, 8).abs() + 1e-3), + ), + "int_denominator": lambda: (None, (torch.randn(4, 8), 2)), +} + +@common.parametrize("data", test_data) +def test_div_tensor_mode_tosa_FP(data): + mode, inputs = data() model = DivTensorModeFloat(mode) - inputs = make_float_div_inputs() pipeline = TosaPipelineFP[input_tt]( model, @@ -61,11 +66,10 @@ def test_div_tensor_mode_tosa_FP(mode): pipeline.run() -@pytest.mark.parametrize("mode", [None, "floor", "trunc"]) -def test_div_tensor_mode_tosa_INT(mode): - +@common.parametrize("data", test_data) +def test_div_tensor_mode_tosa_INT(data): + mode, inputs = data() model = DivTensorModeFloat(mode) - inputs = make_float_div_inputs() pipeline = TosaPipelineINT[input_tt]( model, @@ -79,11 +83,12 @@ def test_div_tensor_mode_tosa_INT(mode): @common.XfailIfNoCorstone300 -@pytest.mark.parametrize("mode", [None, "floor"]) -def test_div_tensor_mode_u55_INT(mode): - +@common.parametrize( + "data", test_data, xfails={"mode_trunc": "CPU op missing in unittests"} +) +def test_div_tensor_mode_u55_INT(data): + mode, inputs = data() model = DivTensorModeFloat(mode) - inputs = make_float_div_inputs() pipeline = EthosU55PipelineINT[input_tt]( model, @@ -97,11 +102,10 @@ def test_div_tensor_mode_u55_INT(mode): @common.XfailIfNoCorstone320 -@pytest.mark.parametrize("mode", [None, "floor", "trunc"]) -def test_div_tensor_mode_u85_INT(mode): - +@common.parametrize("data", test_data) +def test_div_tensor_mode_u85_INT(data): + mode, inputs = data() model = DivTensorModeFloat(mode) - inputs = make_float_div_inputs() pipeline = EthosU85PipelineINT[input_tt]( model, @@ -115,11 +119,10 @@ def test_div_tensor_mode_u85_INT(mode): @common.SkipIfNoModelConverter -@pytest.mark.parametrize("mode", [None, "floor", "trunc"]) -def test_div_tensor_mode_vgf_INT(mode): - +@common.parametrize("data", test_data) +def test_div_tensor_mode_vgf_INT(data): + mode, inputs = data() model = DivTensorModeFloat(mode) - inputs = make_float_div_inputs() pipeline = VgfPipeline[input_tt]( model, @@ -134,11 +137,10 @@ def test_div_tensor_mode_vgf_INT(mode): @common.SkipIfNoModelConverter -@pytest.mark.parametrize("mode", [None, "floor", "trunc"]) -def test_div_tensor_mode_vgf_FP(mode): - +@common.parametrize("data", test_data) +def test_div_tensor_mode_vgf_FP(data): + mode, inputs = data() model = DivTensorModeFloat(mode) - inputs = make_float_div_inputs() pipeline = VgfPipeline[input_tt]( model,