diff --git a/exir/tests/test_quant_fusion_pass.py b/exir/tests/test_quant_fusion_pass.py index 3097f09c430..e3073197b2b 100644 --- a/exir/tests/test_quant_fusion_pass.py +++ b/exir/tests/test_quant_fusion_pass.py @@ -37,6 +37,7 @@ from torch.testing import FileCheck from torchao.quantization.granularity import PerAxis, PerGroup from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_ +from torchao.quantization.utils import compute_error class TestQuantFusionPass(unittest.TestCase): @@ -470,7 +471,8 @@ def _test_embedding_torchao( # Compare numerics actual_outputs = m.exported_program().module()(*example_inputs) - self.assertTrue(torch.allclose(expected_outputs, actual_outputs)) + sqnr = compute_error(expected_outputs, actual_outputs) + self.assertTrue(sqnr >= 50, f"Got sqnr {sqnr}") # Can lower to executorch exec_prog = m.to_executorch() # noqa @@ -488,7 +490,8 @@ def _test_embedding_torchao( ) actual_outputs2 = m_copy.exported_program().module()(*example_inputs) - self.assertTrue(torch.allclose(expected_outputs, actual_outputs2)) + sqnr = compute_error(expected_outputs, actual_outputs2) + self.assertTrue(sqnr >= 50, f"Got sqnr {sqnr}") # Can lower to executorch exec_prog2 = m_copy.to_executorch() # noqa