diff --git a/backends/nxp/tests/test_batch_norm_fusion.py b/backends/nxp/tests/test_batch_norm_fusion.py
index 788d04c6dad..fce11ce5aa2 100644
--- a/backends/nxp/tests/test_batch_norm_fusion.py
+++ b/backends/nxp/tests/test_batch_norm_fusion.py
@@ -105,10 +105,11 @@ def test_batch_norm_conv_fusing(bias: bool, input_shape: list[int]):
     og_nodes = list(program.graph.nodes)
     transformed_nodes = list(graph_module_out.graph.nodes)
 
-    assert len(og_nodes) == (11 if bias else 10)
-    assert og_nodes[9 if bias else 8].target.__name__ == "batch_norm.default"
+    assert any(
+        node.op == "call_function" and node.target.__name__ == "batch_norm.default"
+        for node in og_nodes
+    )
 
-    assert len(transformed_nodes) == 5
     assert not any(
         node.op == "call_function" and "batch_norm" in node.target.__name__
         for node in transformed_nodes
@@ -118,7 +119,7 @@ def test_batch_norm_conv_fusing(bias: bool, input_shape: list[int]):
     input_data = torch.randn(input_shape, dtype=torch.float32)
     out1 = og_module(input_data).detach().numpy()
     out2 = graph_module_out(input_data).detach().numpy()
-    assert np.allclose(out1, out2, atol=3.0e-7)
+    torch.testing.assert_close(out1, out2)
 
 
 @pytest.mark.parametrize(
@@ -139,10 +140,11 @@ def test_batch_norm_linear_fusing(bias: bool):
     og_nodes = list(og_module.graph.nodes)
     transformed_nodes = list(graph_module_out.graph.nodes)
 
-    assert len(og_nodes) == (11 if bias else 10)
-    assert og_nodes[8 if bias else 7].target.__name__ == "linear.default"
+    assert any(
+        node.op == "call_function" and node.target.__name__ == "linear.default"
+        for node in og_nodes
+    )
 
-    assert len(transformed_nodes) == 5
     assert not any(
         node.op == "call_function" and "batch_norm" in node.target.__name__
         for node in transformed_nodes
@@ -152,7 +154,7 @@ def test_batch_norm_linear_fusing(bias: bool):
     input_data = torch.randn(input_shape, dtype=torch.float32)
     out1 = og_module(input_data).detach().numpy()
     out2 = graph_module_out(input_data).detach().numpy()
-    assert np.allclose(out1, out2, atol=1.2e-7)
+    torch.testing.assert_close(out1, out2)
 
 
 @pytest.mark.parametrize(