diff --git a/backends/arm/test/models/test_w2l_arm.py b/backends/arm/test/models/test_w2l_arm.py index 32b25a18fd8..c627cd7f887 100644 --- a/backends/arm/test/models/test_w2l_arm.py +++ b/backends/arm/test/models/test_w2l_arm.py @@ -36,7 +36,6 @@ class TestW2L(unittest.TestCase): input_frames = 400 num_features = 1 - w2l = models.Wav2Letter(num_features=num_features).eval() model_example_inputs = get_test_inputs(batch_size, num_features, input_frames) all_operators = [ @@ -45,11 +44,17 @@ class TestW2L(unittest.TestCase): "executorch_exir_dialects_edge__ops_aten_relu_default", ] + @staticmethod + def create_model(input_type: str = "waveform"): + return models.Wav2Letter( + num_features=TestW2L.num_features, input_type=input_type + ).eval() + @pytest.mark.slow # about 3min on std laptop def test_w2l_tosa_FP(): pipeline = TosaPipelineFP[input_t]( - TestW2L.w2l, + TestW2L.create_model(), TestW2L.model_example_inputs, aten_op=[], exir_op=TestW2L.all_operators, @@ -62,7 +67,7 @@ def test_w2l_tosa_FP(): @pytest.mark.flaky def test_w2l_tosa_INT(): pipeline = TosaPipelineINT[input_t]( - TestW2L.w2l, + TestW2L.create_model(), TestW2L.model_example_inputs, aten_op=[], exir_op=TestW2L.all_operators, @@ -74,12 +79,14 @@ def test_w2l_tosa_INT(): @pytest.mark.slow @common.XfailIfNoCorstone300 @pytest.mark.xfail( - reason="MLETORCH-1009: Wav2Letter fails on U55 due to unsupported conditions", - strict=False, + reason="Wav2Letter fails on U55 due to insufficient memory", + strict=True, ) def test_w2l_u55_INT(): pipeline = EthosU55PipelineINT[input_t]( - TestW2L.w2l, + # Use "power_spectrum" variant because the default ("waveform") has a + # conv1d layer with an unsupported stride size. + TestW2L.create_model("power_spectrum"), TestW2L.model_example_inputs, aten_ops=[], exir_ops=[], @@ -94,7 +101,7 @@ def test_w2l_u55_INT(): @pytest.mark.skip(reason="Intermittent timeout issue: MLETORCH-856") def test_w2l_u85_INT(): pipeline = EthosU85PipelineINT[input_t]( - TestW2L.w2l, + TestW2L.create_model(), TestW2L.model_example_inputs, aten_ops=[], exir_ops=[], @@ -108,7 +115,7 @@ def test_w2l_u85_INT(): @pytest.mark.slow def test_w2l_vgf_INT(): pipeline = VgfPipeline[input_t]( - TestW2L.w2l, + TestW2L.create_model(), TestW2L.model_example_inputs, aten_op=[], exir_op=TestW2L.all_operators, @@ -121,7 +128,7 @@ def test_w2l_vgf_INT(): @common.SkipIfNoModelConverter def test_w2l_vgf_FP(): pipeline = VgfPipeline[input_t]( - TestW2L.w2l, + TestW2L.create_model(), TestW2L.model_example_inputs, aten_op=[], exir_op=TestW2L.all_operators,