From 3360b258832f93c309fa2be391a5de02b6698d86 Mon Sep 17 00:00:00 2001 From: Hardik Sharma Date: Tue, 24 Mar 2026 17:21:03 -0700 Subject: [PATCH] Fix flaky ReplaceTrivialConvWithLinear pass validation tolerance (#18482) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: The `test_replace_conv2d_with_linear` and `test_replace_conv1d_with_linear` tests validate that replacing trivial convolutions with linear ops produces numerically equivalent outputs. Both operations compute the same dot product (sum of element-wise products), but conv accumulates across spatial dimensions (C,H,W) while linear accumulates over a flattened K dimension. With K=294 (conv2d: 6*7*7) or K=672 (conv1d: 96*7) fp32 terms, different accumulation orders produce diffs up to ~1.2e-05 due to non-associativity of floating-point addition. This is not a correctness issue — the mathematical operation is identical. Relax rtol from 1e-05 to 2e-05 to accommodate fp32 accumulation order differences while remaining tight enough to catch real bugs. Reviewed By: DrJessop Differential Revision: D98001101 --- .../aot/tests/test_replace_ops_passes.py | 24 +++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/backends/cadence/aot/tests/test_replace_ops_passes.py b/backends/cadence/aot/tests/test_replace_ops_passes.py index 95d470644a0..cecf49c58ce 100644 --- a/backends/cadence/aot/tests/test_replace_ops_passes.py +++ b/backends/cadence/aot/tests/test_replace_ops_passes.py @@ -1242,9 +1242,17 @@ def test_replace_conv1d_with_linear(self) -> None: self.assertTrue(result.modified) graph_after_passes = result.graph_module - # Validate numerical accuracy + # Conv and linear compute the same dot product but accumulate fp32 + # terms in different order, so non-associativity of floating-point + # addition produces diffs up to ~1.2e-05. Use rtol=2e-05. inputs = [x, weights, bias] - validate(gm_before, graph_after_passes, inputs, "ReplaceTrivialConvWithLinear") + validate( + gm_before, + graph_after_passes, + inputs, + "ReplaceTrivialConvWithLinear", + rtol=2e-5, + ) # Assert that conv1d is trivially converted to linear self.assertEqual( @@ -1278,9 +1286,17 @@ def test_replace_conv2d_with_linear(self) -> None: self.assertTrue(result.modified) graph_after_passes = result.graph_module - # Validate numerical accuracy + # Conv and linear compute the same dot product but accumulate fp32 + # terms in different order, so non-associativity of floating-point + # addition produces diffs up to ~1.2e-05. Use rtol=2e-05. inputs = [x, weights, bias] - validate(gm_before, graph_after_passes, inputs, "ReplaceTrivialConvWithLinear") + validate( + gm_before, + graph_after_passes, + inputs, + "ReplaceTrivialConvWithLinear", + rtol=2e-5, + ) # Assert that conv2d is trivially converted to linear self.assertEqual(