diff --git a/backends/arm/test/misc/test_debug_feats.py b/backends/arm/test/misc/test_debug_feats.py index c2f28f4e9d8..302c5ab80a1 100644 --- a/backends/arm/test/misc/test_debug_feats.py +++ b/backends/arm/test/misc/test_debug_feats.py @@ -21,6 +21,7 @@ TosaPipelineFP, TosaPipelineINT, ) +from executorch.backends.test.harness.stages import StageType input_t1 = Tuple[torch.Tensor] # Input x @@ -104,7 +105,7 @@ def test_INT_artifact(test_data: input_t1): @common.parametrize("test_data", Linear.inputs) def test_numerical_diff_print(test_data: input_t1): - pipeline = TosaPipelineFP[input_t1]( + pipeline = TosaPipelineINT[input_t1]( Linear(), test_data, [], @@ -119,7 +120,9 @@ def test_numerical_diff_print(test_data: input_t1): # not present. try: # Tolerate 0 difference => we want to trigger a numerical diff - tester.run_method_and_compare_outputs(atol=0, rtol=0, qtol=0) + tester.run_method_and_compare_outputs( + stage=StageType.INITIAL_MODEL, atol=0, rtol=0, qtol=0 + ) except AssertionError: pass # Implicit pass test else: