From 9f124952da73d2806816876d60364ace7fe9922c Mon Sep 17 00:00:00 2001 From: Abhinay Kukkadapu Date: Wed, 10 Sep 2025 17:05:42 -0700 Subject: [PATCH] Fix bug where source transform passes untransformed model to next stage (#14186) Summary: `quantize_()` modifies the model in place, so it must be applied to the deep-copied models rather than to the user-passed model; otherwise the user's model is mutated and the untransformed copy is passed to the next stage. Fixes a bug discussed in https://github.com/pytorch/executorch/pull/14171#discussion_r2338013170 Differential Revision: D82167495 --- .../apple/coreml/test/test_coreml_recipes.py | 2 +- .../test/recipes/test_xnnpack_recipes.py | 2 +- export/stages.py | 2 +- export/tests/test_export_stages.py | 27 ++++++++++++++++--- 4 files changed, 26 insertions(+), 7 deletions(-) diff --git a/backends/apple/coreml/test/test_coreml_recipes.py b/backends/apple/coreml/test/test_coreml_recipes.py index f326a8879a4..313e24922d6 100644 --- a/backends/apple/coreml/test/test_coreml_recipes.py +++ b/backends/apple/coreml/test/test_coreml_recipes.py @@ -166,7 +166,7 @@ def forward(self, x): session, example_inputs, atol=1e-3 ) self._compare_eager_unquantized_model_outputs( - session, model, example_inputs + session, model, example_inputs, sqnr_threshold=15 ) def test_int4_weight_only_per_group_validation(self): diff --git a/backends/xnnpack/test/recipes/test_xnnpack_recipes.py b/backends/xnnpack/test/recipes/test_xnnpack_recipes.py index aa470bdcb50..e4a469418cd 100644 --- a/backends/xnnpack/test/recipes/test_xnnpack_recipes.py +++ b/backends/xnnpack/test/recipes/test_xnnpack_recipes.py @@ -154,7 +154,7 @@ def forward(self, x) -> torch.Tensor: ), ExportRecipe.get_recipe( XNNPackRecipeType.TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR, - group_size=8, + group_size=32, ), ] diff --git a/export/stages.py b/export/stages.py index 609e7d197b9..323b327bfa4 100644 --- a/export/stages.py +++ b/export/stages.py @@ -332,7 +332,7 @@ def run(self, artifact: PipelineArtifact) -> None: self._transformed_models = copy.deepcopy(artifact.data) # Apply torchao 
quantize_ to each model - for _, model in artifact.data.items(): + for _, model in self._transformed_models.items(): # pyre-ignore if len(self._quantization_recipe.ao_quantization_configs) > 1: raise ValueError( diff --git a/export/tests/test_export_stages.py b/export/tests/test_export_stages.py index 608aa5adb3c..4e8144bd487 100644 --- a/export/tests/test_export_stages.py +++ b/export/tests/test_export_stages.py @@ -300,11 +300,26 @@ def test_run_with_ao_quantization_configs( artifact = PipelineArtifact(data=models_dict, context={}) stage.run(artifact) - # Verify quantize_ was called with the model and config - mock_quantize.assert_called_once_with(self.model, mock_config, mock_filter_fn) + # Verify quantize_ was called once (with the copied model, not the original) + self.assertEqual(mock_quantize.call_count, 1) + # Verify the config and filter_fn arguments are correct + call_args = mock_quantize.call_args[0] + self.assertNotEqual(self.model, call_args[0]) + self.assertEqual(call_args[1], mock_config) + self.assertEqual(call_args[2], mock_filter_fn) - # Verify unwrap_tensor_subclass was called with the model - mock_unwrap.assert_called_once_with(self.model) + # Verify unwrap_tensor_subclass was called once (with the copied model) + self.assertEqual(mock_unwrap.call_count, 1) + + # Verify that the original models_dict is unchanged + self.assertEqual(models_dict, {"forward": self.model}) + + # Verify that the result artifact data contains valid models + result_artifact = stage.get_artifacts() + self.assertIn("forward", result_artifact.data) + self.assertIsNotNone(result_artifact.data["forward"]) + # verify the result model is NOT the same object as the original + self.assertIsNot(result_artifact.data["forward"], self.model) class TestQuantizeStage(unittest.TestCase): @@ -398,6 +413,10 @@ def test_run_with_quantizers( self.assertIn("forward", result_artifact.data) self.assertEqual(result_artifact.data["forward"], mock_quantized_model) + # Verify that the original 
model in the input artifact is unchanged + self.assertEqual(artifact.data["forward"], self.model) + self.assertIsNot(result_artifact.data["forward"], self.model) + def test_run_empty_example_inputs(self) -> None: """Test error when example inputs list is empty.""" mock_quantizer = Mock()