From 7ba99a6690dd7d1d4df0d67f913151a038664587 Mon Sep 17 00:00:00 2001
From: Gregory Comer
Date: Mon, 13 Apr 2026 14:18:26 -0700
Subject: [PATCH] Update emformer tests to avoid 0/1 specialization issue

Summary: The tests break in some dynamo configurations. Set min batch to 2 to
be exportable in all.

Reviewed By: abhinaykukkadapu

Differential Revision: D100391669
---
 backends/test/harness/tester.py               |  3 +++
 backends/xnnpack/test/models/emformer_rnnt.py | 14 ++++++++++++--
 2 files changed, 15 insertions(+), 2 deletions(-)

diff --git a/backends/test/harness/tester.py b/backends/test/harness/tester.py
index e96904aedd8..ea5fd21cb99 100644
--- a/backends/test/harness/tester.py
+++ b/backends/test/harness/tester.py
@@ -131,6 +131,9 @@ def generate_random_inputs(self):
                 assert isinstance(self.example_inputs[arg_idx], torch.Tensor)
                 ex_shape = list(self.example_inputs[arg_idx].shape)
                 dynamic_dim_spec = self.dynamic_shapes[arg_idx]
+                if dynamic_dim_spec is None or dynamic_dim_spec == {}:
+                    input_shapes.append(torch.Size(ex_shape))
+                    continue
                 for dim_idx, dim_spec in dynamic_dim_spec.items():
                     assert dim_idx < len(ex_shape)
                     if isinstance(dim_spec, torch.export.dynamic_shapes._DerivedDim):
diff --git a/backends/xnnpack/test/models/emformer_rnnt.py b/backends/xnnpack/test/models/emformer_rnnt.py
index 7881be94921..5744ae6dfc1 100644
--- a/backends/xnnpack/test/models/emformer_rnnt.py
+++ b/backends/xnnpack/test/models/emformer_rnnt.py
@@ -50,6 +50,12 @@ def test_fp32_emformer_joiner(self):
 
     def test_fp32_emformer_joiner_dynamic(self):
         joiner = self.Joiner()
+        example_inputs = (
+            torch.rand([2, 128, 1024]),
+            torch.tensor([128]),
+            torch.rand([2, 128, 1024]),
+            torch.tensor([128]),
+        )
         dynamic_shapes = (
             {0: torch.export.Dim("batch", min=1, max=4)},
             None,
@@ -57,7 +63,7 @@
             None,
         )
         (
-            Tester(joiner, joiner.get_example_inputs(), dynamic_shapes=dynamic_shapes)
+            Tester(joiner, example_inputs, dynamic_shapes=dynamic_shapes)
             .export()
            .to_edge_transform_and_lower()
            .check(["torch.ops.higher_order.executorch_call_delegate"])
@@ -117,6 +123,10 @@ def test_fp32_emformer_transcriber(self):
 
     def test_fp32_emformer_transcriber_dynamic(self):
         transcriber = self.Transcriber()
+        example_inputs = (
+            torch.randn(2, 128, 80),
+            torch.tensor([128]),
+        )
         dynamic_shapes = (
             {0: torch.export.Dim("batch", min=1, max=4)},
             None,
@@ -124,7 +134,7 @@
         (
             Tester(
                 transcriber,
-                transcriber.get_example_inputs(),
+                example_inputs,
                 dynamic_shapes=dynamic_shapes,
             )
             .export()