From 08280ccb3cfe38a98b9bd577dd92f74074e35d75 Mon Sep 17 00:00:00 2001 From: Mengwei Liu Date: Fri, 24 Oct 2025 11:20:40 -0700 Subject: [PATCH 1/2] Fix nxp unittests Let's not hard code the number of nodes in the graph --- backends/nxp/tests/test_batch_norm_fusion.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/backends/nxp/tests/test_batch_norm_fusion.py b/backends/nxp/tests/test_batch_norm_fusion.py index 788d04c6dad..a9646a147ea 100644 --- a/backends/nxp/tests/test_batch_norm_fusion.py +++ b/backends/nxp/tests/test_batch_norm_fusion.py @@ -105,10 +105,8 @@ def test_batch_norm_conv_fusing(bias: bool, input_shape: list[int]): og_nodes = list(program.graph.nodes) transformed_nodes = list(graph_module_out.graph.nodes) - assert len(og_nodes) == (11 if bias else 10) - assert og_nodes[9 if bias else 8].target.__name__ == "batch_norm.default" + assert any(node.target.__name__ == "batch_norm.default" for node in og_nodes) - assert len(transformed_nodes) == 5 assert not any( node.op == "call_function" and "batch_norm" in node.target.__name__ for node in transformed_nodes @@ -118,7 +116,7 @@ def test_batch_norm_conv_fusing(bias: bool, input_shape: list[int]): input_data = torch.randn(input_shape, dtype=torch.float32) out1 = og_module(input_data).detach().numpy() out2 = graph_module_out(input_data).detach().numpy() - assert np.allclose(out1, out2, atol=3.0e-7) + torch.testing.assert_close(out1, out2) @pytest.mark.parametrize( @@ -139,10 +137,8 @@ def test_batch_norm_linear_fusing(bias: bool): og_nodes = list(og_module.graph.nodes) transformed_nodes = list(graph_module_out.graph.nodes) - assert len(og_nodes) == (11 if bias else 10) - assert og_nodes[8 if bias else 7].target.__name__ == "linear.default" + assert any(node.target.__name__ == "linear.default" for node in og_nodes) - assert len(transformed_nodes) == 5 assert not any( node.op == "call_function" and "batch_norm" in node.target.__name__ for node in transformed_nodes @@ -152,7 +148,7 @@ def test_batch_norm_linear_fusing(bias: bool): input_data = torch.randn(input_shape, dtype=torch.float32) out1 = og_module(input_data).detach().numpy() out2 = graph_module_out(input_data).detach().numpy() - assert np.allclose(out1, out2, atol=1.2e-7) + torch.testing.assert_close(out1, out2) @pytest.mark.parametrize( From 545539071ad504e64ea2f638c642d22b168f3b92 Mon Sep 17 00:00:00 2001 From: Mengwei Liu Date: Fri, 24 Oct 2025 11:43:36 -0700 Subject: [PATCH 2/2] Fix --- backends/nxp/tests/test_batch_norm_fusion.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/backends/nxp/tests/test_batch_norm_fusion.py b/backends/nxp/tests/test_batch_norm_fusion.py index a9646a147ea..fce11ce5aa2 100644 --- a/backends/nxp/tests/test_batch_norm_fusion.py +++ b/backends/nxp/tests/test_batch_norm_fusion.py @@ -105,7 +105,10 @@ def test_batch_norm_conv_fusing(bias: bool, input_shape: list[int]): og_nodes = list(program.graph.nodes) transformed_nodes = list(graph_module_out.graph.nodes) - assert any(node.target.__name__ == "batch_norm.default" for node in og_nodes) + assert any( + node.op == "call_function" and node.target.__name__ == "batch_norm.default" + for node in og_nodes + ) assert not any( node.op == "call_function" and "batch_norm" in node.target.__name__ @@ -137,7 +140,10 @@ def test_batch_norm_linear_fusing(bias: bool): og_nodes = list(og_module.graph.nodes) transformed_nodes = list(graph_module_out.graph.nodes) - assert any(node.target.__name__ == "linear.default" for node in og_nodes) + assert any( + node.op == "call_function" and node.target.__name__ == "linear.default" + for node in og_nodes + ) assert not any( node.op == "call_function" and "batch_norm" in node.target.__name__