From 69d4eb1adce36ff28ea5f35d8dcc7b07d7ce019e Mon Sep 17 00:00:00 2001
From: Scott Roy
Date: Fri, 23 May 2025 09:23:28 -0700
Subject: [PATCH] Add backward compatible types to pt2e prepare (#11080)

Summary:
Pull Request resolved: https://github.com/pytorch/executorch/pull/11080

X-link: https://github.com/pytorch/ao/pull/2244

Differential Revision: D75248288
---
 backends/cadence/aot/compiler.py       | 2 +-
 examples/xnnpack/quantization/utils.py | 2 +-
 extension/llm/export/builder.py        | 2 +-
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/backends/cadence/aot/compiler.py b/backends/cadence/aot/compiler.py
index a54954a3e3c..438f07ba15f 100644
--- a/backends/cadence/aot/compiler.py
+++ b/backends/cadence/aot/compiler.py
@@ -123,7 +123,7 @@ def prepare_and_convert_pt2(
     assert isinstance(model_gm, torch.fx.GraphModule)
 
     # Prepare
-    prepared_model = prepare_pt2e(model_gm, quantizer)  # pyre-ignore[6]
+    prepared_model = prepare_pt2e(model_gm, quantizer)
 
     # Calibrate
     # If no calibration data is provided, use the inputs
diff --git a/examples/xnnpack/quantization/utils.py b/examples/xnnpack/quantization/utils.py
index deb905ab405..d7648daf5da 100644
--- a/examples/xnnpack/quantization/utils.py
+++ b/examples/xnnpack/quantization/utils.py
@@ -33,7 +33,7 @@ def quantize(
         is_dynamic=is_dynamic,
     )
     quantizer.set_global(operator_config)
-    m = prepare_pt2e(model, quantizer)  # pyre-ignore[6]
+    m = prepare_pt2e(model, quantizer)
     # calibration
     m(*example_inputs)
     m = convert_pt2e(m)
diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py
index b0da14e965e..feae9f45861 100644
--- a/extension/llm/export/builder.py
+++ b/extension/llm/export/builder.py
@@ -373,7 +373,7 @@ def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManage
             ), "Please run export() first"
             m = prepare_pt2e(
                 self.pre_autograd_graph_module,  # pyre-ignore[6]
-                composed_quantizer,  # pyre-ignore[6]
+                composed_quantizer,
             )
             logging.info(
                 f"Calibrating with tasks: {self.calibration_tasks}, limit: {self.calibration_limit}, calibration_data: {self.calibration_data}, tokenizer_path: {self.tokenizer_path}, seq_length: {self.calibration_seq_length}"
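
Note on the change, for reviewers unfamiliar with the flow: pyre error [6] is
"incompatible parameter type". These suppressions existed because prepare_pt2e
previously typed its arguments too narrowly for the graph modules and composed
quantizers these call sites pass; the backward compatible types added in the
X-linked torchao PR let the calls type-check as written. A minimal sketch of
the prepare -> calibrate -> convert flow these three call sites follow is
below. The toy model, example inputs, and import paths are illustrative
assumptions (paths vary across torch/torchao releases), not code from this
patch.

    import torch
    from torch.export import export_for_training
    from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
    from torch.ao.quantization.quantizer.xnnpack_quantizer import (
        XNNPACKQuantizer,
        get_symmetric_quantization_config,
    )

    # Toy model standing in for the real models at the patched call sites.
    model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU()).eval()
    example_inputs = (torch.randn(1, 8),)

    # Export to a pre-autograd graph module (older releases used
    # torch._export.capture_pre_autograd_graph instead).
    graph_module = export_for_training(model, example_inputs).module()

    # Configure a quantizer; set_global mirrors the call in
    # examples/xnnpack/quantization/utils.py.
    quantizer = XNNPACKQuantizer()
    quantizer.set_global(get_symmetric_quantization_config(is_per_channel=True))

    # prepare_pt2e inserts observers; this is the call whose parameter
    # types the torchao change loosens, so no pyre-ignore[6] is needed.
    prepared = prepare_pt2e(graph_module, quantizer)

    # Calibrate by running representative inputs through the prepared model.
    prepared(*example_inputs)

    # convert_pt2e replaces observers with quantize/dequantize ops.
    quantized = convert_pt2e(prepared)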