From 7e4d6d0375cd08fb6aca58695c2bede9216025a3 Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Wed, 24 Sep 2025 11:04:15 -0700 Subject: [PATCH] init --- exir/tests/test_quant_fusion_pass.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/exir/tests/test_quant_fusion_pass.py b/exir/tests/test_quant_fusion_pass.py index 3097f09c430..e3073197b2b 100644 --- a/exir/tests/test_quant_fusion_pass.py +++ b/exir/tests/test_quant_fusion_pass.py @@ -37,6 +37,7 @@ from torch.testing import FileCheck from torchao.quantization.granularity import PerAxis, PerGroup from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_ +from torchao.quantization.utils import compute_error class TestQuantFusionPass(unittest.TestCase): @@ -470,7 +471,8 @@ def _test_embedding_torchao( # Compare numerics actual_outputs = m.exported_program().module()(*example_inputs) - self.assertTrue(torch.allclose(expected_outputs, actual_outputs)) + sqnr = compute_error(expected_outputs, actual_outputs) + self.assertTrue(sqnr >= 50, f"Got sqnr {sqnr}") # Can lower to executorch exec_prog = m.to_executorch() # noqa @@ -488,7 +490,8 @@ def _test_embedding_torchao( ) actual_outputs2 = m_copy.exported_program().module()(*example_inputs) - self.assertTrue(torch.allclose(expected_outputs, actual_outputs2)) + sqnr = compute_error(expected_outputs, actual_outputs2) + self.assertTrue(sqnr >= 50, f"Got sqnr {sqnr}") # Can lower to executorch exec_prog2 = m_copy.to_executorch() # noqa