diff --git a/backends/arm/test/passes/test_fuse_constant_ops_pass.py b/backends/arm/test/passes/test_fuse_constant_ops_pass.py index 5e759d7a824..4ec6942430f 100644 --- a/backends/arm/test/passes/test_fuse_constant_ops_pass.py +++ b/backends/arm/test/passes/test_fuse_constant_ops_pass.py @@ -4,7 +4,6 @@ # LICENSE file in the root directory of this source tree. import operator -import unittest from typing import Tuple import torch @@ -13,10 +12,7 @@ FuseConstantArgsPass, ) from executorch.backends.arm.test import common -from executorch.backends.arm.test.tester.test_pipeline import ( - PassPipeline, - TosaPipelineBI, -) +from executorch.backends.arm.test.tester.test_pipeline import PassPipeline input_t = Tuple[torch.Tensor] # Input x @@ -111,15 +107,14 @@ def test_fuse_const_ops_tosa_MI(module: torch.nn.Module): pipeline.run() -@unittest.skip("Test failing on internal CI") @common.parametrize("module", modules) def test_fuse_const_ops_tosa_BI(module: torch.nn.Module): - pipeline = TosaPipelineBI[input_t]( + pipeline = PassPipeline[input_t]( module, (torch.rand(10, 10),), - [], - [], quantize=True, - use_to_edge_transform_and_lower=True, + ops_before_pass=module.ops_before_pass, + ops_after_pass=module.ops_after_pass, + passes_with_exported_program=[ComputeConstantOpsAOT, FuseConstantArgsPass], ) pipeline.run()