diff --git a/backends/xnnpack/test/passes/test_propagate_custom_meta_pass.py b/backends/xnnpack/test/passes/test_propagate_custom_meta_pass.py
index f76ee3a0868..2d876b372cb 100644
--- a/backends/xnnpack/test/passes/test_propagate_custom_meta_pass.py
+++ b/backends/xnnpack/test/passes/test_propagate_custom_meta_pass.py
@@ -20,9 +20,15 @@
 )
 from executorch.backends.xnnpack.test.tester import Quantize as XNNPackQuantize, Tester
 from executorch.backends.xnnpack.test.tester.tester import ToEdgeTransformAndLower
+
+from executorch.exir import ExecutorchProgramManager
+from executorch.exir._serialize import _deserialize_pte_binary
 from executorch.exir.passes.external_constants_pass import (
     delegate_external_constants_pass_unlifted,
 )
+from executorch.extension.flat_tensor.serialize.serialize import (
+    _deserialize_to_flat_tensor,
+)
 from torchao.quantization.granularity import PerGroup
 from torchao.quantization.quant_api import Int8DynamicActivationIntxWeightConfig
 
@@ -87,7 +93,7 @@ def _test_linear(
         self,
         partitioner: XnnpackPartitioner,
         quantization_stage: Union[BaseStages.Quantize, BaseStages.Quantize_],
-    ):
+    ) -> ExecutorchProgramManager:
         eager_model = self.ModuleLinear(
             in_size=1,
             input_channels=32,
@@ -106,7 +112,7 @@ def _test_linear(
         exec = tester.get_artifact()
         program_buffer = exec.buffer
         self.assertEqual(len(exec._tensor_data), 1)
-        data_buffer = bytes(exec._tensor_data.pop("model"))
+        data_buffer = bytes(exec._tensor_data["model"])
         self.assertTrue(len(data_buffer) > 200)
 
         from executorch.extension.pybindings import portable_lib as runtime
@@ -122,6 +128,8 @@ def _test_linear(
         #     test_inputs
         # )
 
+        return exec
+
     def test_quantize_(self):
         # Quantize with torchao quantize_ API.
         DynamicallyQuantizedPartitioner = XnnpackPartitioner(
@@ -132,9 +140,16 @@ def test_quantize_(self):
             weight_dtype=torch.int4,
             weight_granularity=PerGroup(32),
         )
-        self._test_linear(
+        exec = self._test_linear(
             DynamicallyQuantizedPartitioner, BaseStages.Quantize_(config=linear_config)
         )
+        # PTE file has no named data.
+        pte_file = _deserialize_pte_binary(exec.buffer)
+        self.assertEqual(pte_file.named_data, None)
+
+        # PTD file contains quantized weight and scale.
+        ptd_file = _deserialize_to_flat_tensor(bytes(exec._tensor_data["model"]))
+        self.assertEqual(len(ptd_file.named_data), 2)
 
     def test_pt2e_quantize(self):
         # Quantize with pt2e quantize.
@@ -156,6 +171,15 @@ def test_pt2e_quantize(self):
                 partitioner = XnnpackPartitioner(
                     config_precisions=precision, per_op_mode=per_op_mode
                 )
-                self._test_linear(
+                exec = self._test_linear(
                     partitioner, XNNPackQuantize(quantization_config=quant_config)
                 )
+                # PTE file has no named data.
+                pte_file = _deserialize_pte_binary(exec.buffer)
+                self.assertEqual(pte_file.named_data, None)
+
+                # PTD file contains quantized weight, and potentially scale.
+                ptd_file = _deserialize_to_flat_tensor(
+                    bytes(exec._tensor_data["model"])
+                )
+                self.assertTrue(len(ptd_file.named_data) >= 1)
diff --git a/examples/models/llama/source_transformation/quantize.py b/examples/models/llama/source_transformation/quantize.py
index b9076f90795..d60d90bad4b 100644
--- a/examples/models/llama/source_transformation/quantize.py
+++ b/examples/models/llama/source_transformation/quantize.py
@@ -179,9 +179,6 @@ def filter_fn(m, fqn):
                 ),
                 filter_fn=filter_fn,
             )
-
-            model = unwrap_tensor_subclass(model)
-
             # TODO: deal with checkpoint / computation dtype decoupling.
             if verbose:
diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py
index 675c0179ebb..ae15dded91d 100644
--- a/extension/llm/export/builder.py
+++ b/extension/llm/export/builder.py
@@ -38,7 +38,6 @@
 from torch.nn.attention import SDPBackend
 from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
 from torchao.quantization.pt2e.quantizer import ComposableQuantizer, Quantizer
-from torchao.utils import unwrap_tensor_subclass
 
 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
 logging.basicConfig(level=logging.INFO, format=FORMAT)
@@ -137,15 +136,15 @@ def __init__(
         if not self.dynamic_shapes and self.enable_dynamic_shape:
             if not self.use_kv_cache:
                 # Only one input argument: tokens
-                # Here we -1 due to export limitation: https://gist.github.com/larryliu0820/419022a57e24d5e64150e325a685eaad
+                # Here we use -1 due to export limitation: https://gist.github.com/larryliu0820/419022a57e24d5e64150e325a685eaad
                 self.dynamic_shapes = (
                     {1: torch.export.Dim("token_dim", max=self.max_seq_len - 1)},
                 )
             else:
                 # Two input arguments: tokens and input_pos but input_pos is static shape.
-
+                # Here we use -1 due to export limitation (same as non-kv-cache case above).
                 self.dynamic_shapes = (
-                    {1: torch.export.Dim("token_dim", max=self.max_seq_len)},
+                    {1: torch.export.Dim("token_dim", max=self.max_seq_len - 1)},
                     {"input_pos": {0: 1}},
                 )
 
@@ -203,11 +202,6 @@ def _get_edge_config(self) -> EdgeCompileConfig:
         return edge_config
 
     def _export(self, module: Optional[torch.nn.Module] = None) -> ExportedProgram:
-        if module is not None:
-            unwrap_tensor_subclass(module)
-        else:
-            unwrap_tensor_subclass(self.model)
-
         dynamic_shape = self._get_dynamic_shape()
         # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
         # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
@@ -226,6 +220,8 @@ def _export(self, module: Optional[torch.nn.Module] = None) -> ExportedProgram:
                 dynamic_shapes=dynamic_shape,
                 strict=True,
             )
+            # Functionalize the graph, and decompose subclasses from torchao quantize.
+            exported_module = exported_module.run_decompositions({})
         return exported_module
 
     def export(self) -> "LLMEdgeManager":
diff --git a/extension/llm/export/test/test_builder.py b/extension/llm/export/test/test_builder.py
index 8bf591813ec..0f552460d78 100644
--- a/extension/llm/export/test/test_builder.py
+++ b/extension/llm/export/test/test_builder.py
@@ -88,7 +88,8 @@ def test_get_dynamic_shape_with_dynamic_shape_enabled_with_kv_cache(self) -> Non
         # Check first element (tokens dimension)
         self.assertIsInstance(result[0], dict)
         self.assertIn(1, result[0])
-        self.assertEqual(result[0][1].max, self.max_seq_len)
+        # max is max_seq_len - 1 due to export limitation
+        self.assertEqual(result[0][1].max, self.max_seq_len - 1)
 
         # Check second element (input_pos dimension)
         self.assertIsInstance(result[1], dict)
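
Note for reviewers: below is a minimal sketch of the export flow this patch settles on. Instead of calling unwrap_tensor_subclass() on the quantized module before torch.export, the exported program is passed through run_decompositions({}), which functionalizes the graph and decomposes the torchao tensor subclasses in one step. The SmallLinear module and its shapes are hypothetical stand-ins for illustration, not code from this patch; only the quantization config mirrors the one used in test_quantize_ above.

    import torch
    from torchao.quantization.granularity import PerGroup
    from torchao.quantization.quant_api import (
        Int8DynamicActivationIntxWeightConfig,
        quantize_,
    )

    class SmallLinear(torch.nn.Module):
        """Hypothetical stand-in for the modules exported in this patch."""

        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(32, 32, bias=False)

        def forward(self, x):
            return self.linear(x)

    model = SmallLinear().eval()

    # quantize_ swaps the weights for torchao tensor subclasses in place,
    # using the same config as test_quantize_ in the diff above.
    quantize_(
        model,
        Int8DynamicActivationIntxWeightConfig(
            weight_dtype=torch.int4,
            weight_granularity=PerGroup(32),
        ),
    )

    # Old flow: model = unwrap_tensor_subclass(model), then export.
    # New flow: export first, then run an empty decomposition table, which
    # functionalizes the graph and desugars the tensor subclasses.
    ep = torch.export.export(model, (torch.randn(1, 32),), strict=True)
    ep = ep.run_decompositions({})

One design note: with an empty decomposition table, run_decompositions({}) leaves ATen ops intact while still functionalizing, so the explicit unwrap step becomes redundant rather than merely relocated.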