From 4238f5cfb8466d9040f7a728dd429a9f12677f43 Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Tue, 20 May 2025 13:56:14 -0700 Subject: [PATCH 1/9] migrate convert/prepare to torchao --- backends/example/example_quantizer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/backends/example/example_quantizer.py b/backends/example/example_quantizer.py index b9291e1d48e..c649bc5a0f4 100644 --- a/backends/example/example_quantizer.py +++ b/backends/example/example_quantizer.py @@ -18,7 +18,6 @@ Quantizer, ) - def get_uint8_tensor_spec(observer_or_fake_quant_ctr): return QuantizationSpec( dtype=torch.uint8, From 05067b05393843af43c53b7113dfea21dd0be414 Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Tue, 20 May 2025 21:58:28 -0700 Subject: [PATCH 2/9] init --- .lintrunner.toml | 2 -- extension/llm/export/builder.py | 3 +-- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/.lintrunner.toml b/.lintrunner.toml index d4cf2531ce1..25b2409bbf0 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -391,8 +391,6 @@ exclude_patterns = [ "backends/vulkan/quantizer/**", "backends/vulkan/test/**", "backends/cadence/aot/quantizer/**", - "backends/qualcomm/quantizer/**", - "examples/qualcomm/**", "backends/xnnpack/quantizer/**", "backends/xnnpack/test/**", "exir/tests/test_passes.py", diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py index b0da14e965e..8563e6f6a05 100644 --- a/extension/llm/export/builder.py +++ b/extension/llm/export/builder.py @@ -35,11 +35,10 @@ from executorch.extension.llm.export.export_passes import RemoveRedundantTransposes from pytorch_tokenizers import get_tokenizer -from torch.ao.quantization.quantizer import Quantizer -from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer from torch.export import export_for_training, ExportedProgram from torch.nn.attention import SDPBackend from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e +from torchao.quantization.pt2e.quantizer import ComposableQuantizer, Quantizer from torchao.utils import unwrap_tensor_subclass FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" From 6ae741d6c42df3caefb438dd0e046cc1ace3c009 Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Tue, 20 May 2025 22:01:20 -0700 Subject: [PATCH 3/9] rebase --- backends/example/example_quantizer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/backends/example/example_quantizer.py b/backends/example/example_quantizer.py index c649bc5a0f4..b9291e1d48e 100644 --- a/backends/example/example_quantizer.py +++ b/backends/example/example_quantizer.py @@ -18,6 +18,7 @@ Quantizer, ) + def get_uint8_tensor_spec(observer_or_fake_quant_ctr): return QuantizationSpec( dtype=torch.uint8, From 2bc6207759bd243991c6b5f5df24284eb03c7917 Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Wed, 21 May 2025 08:36:04 -0700 Subject: [PATCH 4/9] up --- .lintrunner.toml | 5 +++++ extension/llm/export/builder.py | 3 ++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/.lintrunner.toml b/.lintrunner.toml index 25b2409bbf0..659254ef80c 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -391,6 +391,11 @@ exclude_patterns = [ "backends/vulkan/quantizer/**", "backends/vulkan/test/**", "backends/cadence/aot/quantizer/**", +<<<<<<< HEAD +======= + "backends/qualcomm/quantizer/**", + "examples/qualcomm/**", +>>>>>>> 362501568 (up) "backends/xnnpack/quantizer/**", "backends/xnnpack/test/**", "exir/tests/test_passes.py", diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py index 8563e6f6a05..b0da14e965e 100644 --- a/extension/llm/export/builder.py +++ b/extension/llm/export/builder.py @@ -35,10 +35,11 @@ from executorch.extension.llm.export.export_passes import RemoveRedundantTransposes from pytorch_tokenizers import get_tokenizer +from torch.ao.quantization.quantizer import Quantizer +from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer from torch.export import export_for_training, ExportedProgram from torch.nn.attention import SDPBackend from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e -from torchao.quantization.pt2e.quantizer import ComposableQuantizer, Quantizer from torchao.utils import unwrap_tensor_subclass FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" From 1904fa88e4750e83a6a944fa20b18e610bfd2933 Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Wed, 21 May 2025 11:16:06 -0700 Subject: [PATCH 5/9] init --- .lintrunner.toml | 5 -- backends/qualcomm/_passes/qnn_pass_manager.py | 4 +- backends/qualcomm/builders/node_visitor.py | 4 +- backends/qualcomm/partition/utils.py | 4 +- backends/qualcomm/quantizer/annotators.py | 55 +++++++++---------- .../qualcomm/quantizer/custom_annotation.py | 6 +- .../observers/per_block_param_observer.py | 4 +- .../observers/per_channel_param_observer.py | 2 +- backends/qualcomm/quantizer/qconfig.py | 11 ++-- backends/qualcomm/quantizer/quantizer.py | 7 +-- backends/qualcomm/tests/utils.py | 7 ++- examples/qualcomm/oss_scripts/llama/llama.py | 2 +- examples/qualcomm/oss_scripts/moshi/mimi.py | 2 +- examples/qualcomm/utils.py | 11 ++-- 14 files changed, 60 insertions(+), 64 deletions(-) diff --git a/.lintrunner.toml b/.lintrunner.toml index 659254ef80c..25b2409bbf0 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -391,11 +391,6 @@ exclude_patterns = [ "backends/vulkan/quantizer/**", "backends/vulkan/test/**", "backends/cadence/aot/quantizer/**", -<<<<<<< HEAD -======= - "backends/qualcomm/quantizer/**", - "examples/qualcomm/**", ->>>>>>> 362501568 (up) "backends/xnnpack/quantizer/**", "backends/xnnpack/test/**", "exir/tests/test_passes.py", diff --git a/backends/qualcomm/_passes/qnn_pass_manager.py b/backends/qualcomm/_passes/qnn_pass_manager.py index 63c303eb689..9f668a89441 100644 --- a/backends/qualcomm/_passes/qnn_pass_manager.py +++ b/backends/qualcomm/_passes/qnn_pass_manager.py @@ -131,8 +131,8 @@ def get_to_edge_transform_passes( from executorch.backends.qualcomm._passes import utils from executorch.exir.dialects._ops import ops as exir_ops - utils.q_ops.add(exir_ops.edge.pt2e_quant.quantize_affine.default) - utils.dq_ops.add(exir_ops.edge.pt2e_quant.dequantize_affine.default) + utils.q_ops.add(exir_ops.edge.torchao.quantize_affine.default) + utils.dq_ops.add(exir_ops.edge.torchao.dequantize_affine.default) passes_job = ( passes_job if passes_job is not None else get_capture_program_passes() diff --git a/backends/qualcomm/builders/node_visitor.py b/backends/qualcomm/builders/node_visitor.py index 5e9520d4c05..3e0d2eaae2a 100644 --- a/backends/qualcomm/builders/node_visitor.py +++ b/backends/qualcomm/builders/node_visitor.py @@ -254,8 +254,8 @@ def get_quant_encoding_conf( ) # TODO: refactor this when target could be correctly detected per_block_encoding = { - exir_ops.edge.pt2e_quant.quantize_affine.default, - exir_ops.edge.pt2e_quant.dequantize_affine.default, + exir_ops.edge.torchao.quantize_affine.default, + exir_ops.edge.torchao.dequantize_affine.default, } if quant_attrs[QCOM_ENCODING] in per_block_encoding: return self.make_qnn_per_block_config(node, quant_attrs) diff --git a/backends/qualcomm/partition/utils.py b/backends/qualcomm/partition/utils.py index 816d1ac1d9b..05bbd1ff970 100644 --- a/backends/qualcomm/partition/utils.py +++ b/backends/qualcomm/partition/utils.py @@ -57,7 +57,7 @@ def get_skip_decomp_table() -> List[torch._ops.OperatorBase]: torch.ops.aten.upsample_bicubic2d.vec, # This request is ignored because it is in a blocklist. Refer to exir/program/_program.py torch.ops.aten.unbind.int, - torch.ops.pt2e_quant.quantize_affine.default, - torch.ops.pt2e_quant.dequantize_affine.default, + torch.ops.torchao.quantize_affine.default, + torch.ops.torchao.dequantize_affine.default, ] return do_not_decompose diff --git a/backends/qualcomm/quantizer/annotators.py b/backends/qualcomm/quantizer/annotators.py index 469a801feeb..8611549861d 100644 --- a/backends/qualcomm/quantizer/annotators.py +++ b/backends/qualcomm/quantizer/annotators.py @@ -12,20 +12,17 @@ from torch._ops import OpOverload from torch._subclasses import FakeTensor -from torch.ao.quantization.fake_quantize import FixedQParamsFakeQuantize +from torch.fx import Node -from torch.ao.quantization.observer import FixedQParamsObserver -from torch.ao.quantization.quantizer import ( +from torchao.quantization.pt2e import FixedQParamsFakeQuantize, FixedQParamsObserver +from torchao.quantization.pt2e.quantizer import ( + annotate_input_qspec_map, + annotate_output_qspec, DerivedQuantizationSpec, QuantizationAnnotation, QuantizationSpec, SharedQuantizationSpec, ) -from torch.ao.quantization.quantizer.utils import ( - _annotate_input_qspec_map, - _annotate_output_qspec, -) -from torch.fx import Node from .qconfig import ( get_16a16w_qnn_ptq_config, @@ -618,19 +615,19 @@ def annotate_rms_norm(node: Node, quantization_config: QuantizationConfig) -> No return # TODO current only support 16a16w - _annotate_input_qspec_map( + annotate_input_qspec_map( node, act_node, quantization_config.input_activation, ) - _annotate_input_qspec_map( + annotate_input_qspec_map( node, weight_node, quantization_config.input_activation, ) nodes_to_mark_annotated = [node] - _annotate_output_qspec(node, quantization_config.output_activation) + annotate_output_qspec(node, quantization_config.output_activation) _mark_nodes_as_annotated(nodes_to_mark_annotated) @@ -819,25 +816,25 @@ def annotate_group_norm(node: Node, quantization_config: QuantizationConfig) -> if _is_annotated([node]): return - _annotate_input_qspec_map( + annotate_input_qspec_map( node, act_node, quantization_config.input_activation, ) - _annotate_input_qspec_map( + annotate_input_qspec_map( node, weight_node, quantization_config.weight, ) nodes_to_mark_annotated = [node, weight_node] if bias_node: - _annotate_input_qspec_map( + annotate_input_qspec_map( node, bias_node, quantization_config.bias, ) nodes_to_mark_annotated.append(bias_node) - _annotate_output_qspec(node, quantization_config.output_activation) + annotate_output_qspec(node, quantization_config.output_activation) _mark_nodes_as_annotated(nodes_to_mark_annotated) @@ -1002,12 +999,12 @@ def annotate_linear(node: Node, quantization_config: QuantizationConfig) -> None if _is_annotated([node]): return - _annotate_input_qspec_map( + annotate_input_qspec_map( node, act_node, quantization_config.input_activation, ) - _annotate_input_qspec_map( + annotate_input_qspec_map( node, weight_node, quantization_config.weight, @@ -1018,9 +1015,9 @@ def annotate_linear(node: Node, quantization_config: QuantizationConfig) -> None bias_config = quantization_config.bias(node) else: bias_config = quantization_config.bias - _annotate_input_qspec_map(node, bias_node, bias_config) + annotate_input_qspec_map(node, bias_node, bias_config) nodes_to_mark_annotated.append(bias_node) - _annotate_output_qspec(node, quantization_config.output_activation) + annotate_output_qspec(node, quantization_config.output_activation) _mark_nodes_as_annotated(nodes_to_mark_annotated) # We use get_source_partition in pass, but it is the same source for MultiheadAttention, so we need to change its source_fn_stack. @@ -1038,14 +1035,14 @@ def annotate_batch_and_instance_norm( return annotated_args = [act] - _annotate_input_qspec_map( + annotate_input_qspec_map( node, act, quantization_config.input_activation, ) # QNN requires uint8 instead of int8 in 'weight' config if weight is not None: - _annotate_input_qspec_map( + annotate_input_qspec_map( node, weight, quantization_config.input_activation, @@ -1053,14 +1050,14 @@ def annotate_batch_and_instance_norm( annotated_args.append(weight) if bias is not None: - _annotate_input_qspec_map( + annotate_input_qspec_map( node, bias, quantization_config.bias, ) annotated_args.append(bias) - _annotate_output_qspec(node, quantization_config.output_activation) + annotate_output_qspec(node, quantization_config.output_activation) _mark_nodes_as_annotated([node, *annotated_args]) @@ -1070,7 +1067,7 @@ def annotate_getitem(node: Node, quantization_config: QuantizationConfig) -> Non return if _is_float_tensor(node): - _annotate_output_qspec(node, quantization_config.output_activation) + annotate_output_qspec(node, quantization_config.output_activation) _mark_nodes_as_annotated([node]) @@ -1086,32 +1083,32 @@ def annotate_layer_norm(node: Node, quantization_config: QuantizationConfig) -> return input_act_qspec = quantization_config.input_activation - _annotate_input_qspec_map( + annotate_input_qspec_map( node, act_node, input_act_qspec, ) if input_act_qspec.dtype == torch.int32: - _annotate_input_qspec_map( + annotate_input_qspec_map( node, weight_node, get_16a16w_qnn_ptq_config().weight, ) else: - _annotate_input_qspec_map( + annotate_input_qspec_map( node, weight_node, input_act_qspec, ) nodes_to_mark_annotated = [node, weight_node] if bias_node: - _annotate_input_qspec_map( + annotate_input_qspec_map( node, bias_node, quantization_config.bias, ) nodes_to_mark_annotated.append(bias_node) - _annotate_output_qspec(node, quantization_config.output_activation) + annotate_output_qspec(node, quantization_config.output_activation) _mark_nodes_as_annotated(nodes_to_mark_annotated) diff --git a/backends/qualcomm/quantizer/custom_annotation.py b/backends/qualcomm/quantizer/custom_annotation.py index bda91609f1c..0e06015ed91 100644 --- a/backends/qualcomm/quantizer/custom_annotation.py +++ b/backends/qualcomm/quantizer/custom_annotation.py @@ -17,13 +17,13 @@ QuantizationConfig, ) from executorch.exir.dialects._ops import ops as exir_ops -from torch.ao.quantization.observer import FixedQParamsObserver, MinMaxObserver -from torch.ao.quantization.quantizer import ( +from torch.fx import Node +from torchao.quantization.pt2e import FixedQParamsObserver, MinMaxObserver +from torchao.quantization.pt2e.quantizer import ( QuantizationAnnotation, QuantizationSpec, SharedQuantizationSpec, ) -from torch.fx import Node def annotate_mimi_decoder(gm: torch.fx.GraphModule): diff --git a/backends/qualcomm/quantizer/observers/per_block_param_observer.py b/backends/qualcomm/quantizer/observers/per_block_param_observer.py index e60f15c6d9c..802d5706d89 100644 --- a/backends/qualcomm/quantizer/observers/per_block_param_observer.py +++ b/backends/qualcomm/quantizer/observers/per_block_param_observer.py @@ -7,8 +7,8 @@ from typing import Tuple import torch -from torch.ao.quantization.observer import MappingType, PerBlock -from torch.ao.quantization.pt2e._affine_quantization import ( +from torchao.quantization.pt2e import MappingType, PerBlock +from torchao.quantization.pt2e._affine_quantization import ( _get_reduction_params, AffineQuantizedMinMaxObserver, choose_qparams_affine_with_min_max, diff --git a/backends/qualcomm/quantizer/observers/per_channel_param_observer.py b/backends/qualcomm/quantizer/observers/per_channel_param_observer.py index 3c04e620308..9f89f6b0e69 100644 --- a/backends/qualcomm/quantizer/observers/per_channel_param_observer.py +++ b/backends/qualcomm/quantizer/observers/per_channel_param_observer.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. import torch -from torch.ao.quantization.observer import UniformQuantizationObserverBase +from torchao.quantization.pt2e import UniformQuantizationObserverBase # TODO move to torch/ao/quantization/observer.py. diff --git a/backends/qualcomm/quantizer/qconfig.py b/backends/qualcomm/quantizer/qconfig.py index 67968363eb6..e2a9cd83567 100644 --- a/backends/qualcomm/quantizer/qconfig.py +++ b/backends/qualcomm/quantizer/qconfig.py @@ -7,18 +7,19 @@ PerBlockParamObserver, ) from torch import Tensor -from torch.ao.quantization.fake_quantize import ( +from torch.fx import Node +from torchao.quantization.pt2e import ( FakeQuantize, FusedMovingAvgObsFakeQuantize, -) -from torch.ao.quantization.observer import ( MinMaxObserver, MovingAverageMinMaxObserver, MovingAveragePerChannelMinMaxObserver, PerChannelMinMaxObserver, ) -from torch.ao.quantization.quantizer import DerivedQuantizationSpec, QuantizationSpec -from torch.fx import Node +from torchao.quantization.pt2e.quantizer import ( + DerivedQuantizationSpec, + QuantizationSpec, +) @dataclass(eq=True) diff --git a/backends/qualcomm/quantizer/quantizer.py b/backends/qualcomm/quantizer/quantizer.py index 8e65607dd84..9a149e7db87 100644 --- a/backends/qualcomm/quantizer/quantizer.py +++ b/backends/qualcomm/quantizer/quantizer.py @@ -12,8 +12,9 @@ from executorch.backends.qualcomm._passes.qnn_pass_manager import QnnPassManager from torch._ops import OpOverload -from torch.ao.quantization.quantizer import Quantizer from torch.fx import GraphModule +from torchao.quantization.pt2e import UniformQuantizationObserverBase +from torchao.quantization.pt2e.quantizer import Quantizer from .annotators import OP_ANNOTATOR @@ -130,9 +131,7 @@ class ModuleQConfig: is_qat: bool = False is_conv_per_channel: bool = False is_linear_per_channel: bool = False - act_observer: Optional[ - torch.ao.quantization.observer.UniformQuantizationObserverBase - ] = None + act_observer: Optional[UniformQuantizationObserverBase] = None def __post_init__(self): if (self.quant_dtype, self.is_qat) not in QUANT_CONFIG_DICT: diff --git a/backends/qualcomm/tests/utils.py b/backends/qualcomm/tests/utils.py index 6432d67981a..24a8947e265 100644 --- a/backends/qualcomm/tests/utils.py +++ b/backends/qualcomm/tests/utils.py @@ -14,6 +14,7 @@ import numpy as np import torch +import torchao from executorch import exir from executorch.backends.qualcomm._passes.utils import dq_ops from executorch.backends.qualcomm.qnn_preprocess import QnnBackend @@ -537,8 +538,8 @@ def get_qdq_module( torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.quantized_decomposed.quantize_per_channel.default, torch.ops.quantized_decomposed.dequantize_per_channel.default, - torch.ops.pt2e_quant.quantize_affine.default, - torch.ops.pt2e_quant.dequantize_affine.default, + torch.ops.torchao.quantize_affine.default, + torch.ops.torchao.dequantize_affine.default, } if not bypass_check: self.assertTrue(nodes.intersection(q_and_dq)) @@ -569,7 +570,7 @@ def get_prepared_qat_module( quantizer.set_submodule_qconfig_list(submodule_qconfig_list) prepared = prepare_qat_pt2e(m, quantizer) - return torch.ao.quantization.move_exported_model_to_train(prepared) + return torchao.quantization.pt2e.move_exported_model_to_train(prepared) def get_converted_sgd_trained_module( self, diff --git a/examples/qualcomm/oss_scripts/llama/llama.py b/examples/qualcomm/oss_scripts/llama/llama.py index eb46c4fbd91..0dea2d12652 100755 --- a/examples/qualcomm/oss_scripts/llama/llama.py +++ b/examples/qualcomm/oss_scripts/llama/llama.py @@ -81,7 +81,7 @@ from pytorch_tokenizers import get_tokenizer, TiktokenTokenizer from pytorch_tokenizers.llama2c import Llama2cTokenizer as SentencePieceTokenizer -from torch.ao.quantization.observer import MinMaxObserver +from torchao.quantization.pt2e import MinMaxObserver from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e sys.setrecursionlimit(4096) diff --git a/examples/qualcomm/oss_scripts/moshi/mimi.py b/examples/qualcomm/oss_scripts/moshi/mimi.py index 6b59a71ae64..70e339a32d6 100644 --- a/examples/qualcomm/oss_scripts/moshi/mimi.py +++ b/examples/qualcomm/oss_scripts/moshi/mimi.py @@ -37,7 +37,7 @@ from huggingface_hub import hf_hub_download from moshi.models import loaders -from torch.ao.quantization.observer import MinMaxObserver +from torchao.quantization.pt2e import MinMaxObserver def seed_all(seed): diff --git a/examples/qualcomm/utils.py b/examples/qualcomm/utils.py index d0f3d4cb9ed..ec7c4c08d14 100755 --- a/examples/qualcomm/utils.py +++ b/examples/qualcomm/utils.py @@ -19,6 +19,7 @@ import numpy as np import torch +import torchao from executorch.backends.qualcomm.quantizer.quantizer import ( ModuleQConfig, QnnQuantizer, @@ -33,7 +34,7 @@ ) from executorch.exir.capture._config import ExecutorchBackendConfig from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass -from torch.ao.quantization.observer import MovingAverageMinMaxObserver +from torchao.quantization.pt2e import MovingAverageMinMaxObserver from torchao.quantization.pt2e.quantize_pt2e import ( convert_pt2e, prepare_pt2e, @@ -231,7 +232,7 @@ def ptq_calibrate(captured_model, quantizer, dataset): def qat_train(ori_model, captured_model, quantizer, dataset): data, targets = dataset - annotated_model = torch.ao.quantization.move_exported_model_to_train( + annotated_model = torchao.quantization.pt2e.move_exported_model_to_train( prepare_qat_pt2e(captured_model, quantizer) ) optimizer = torch.optim.SGD(annotated_model.parameters(), lr=0.00001) @@ -240,7 +241,9 @@ def qat_train(ori_model, captured_model, quantizer, dataset): print(f"Epoch {i}") if i > 3: # Freeze quantizer parameters - annotated_model.apply(torch.ao.quantization.disable_observer) + annotated_model.apply( + torchao.quantization.pt2e.fake_quantize.disable_observer + ) if i > 2: # Freeze batch norm mean and variance estimates annotated_model.apply(torch.nn.intrinsic.qat.freeze_bn_stats) @@ -252,7 +255,7 @@ def qat_train(ori_model, captured_model, quantizer, dataset): optimizer.step() return convert_pt2e( - torch.ao.quantization.move_exported_model_to_eval(annotated_model), + torchao.quantization.pt2e.move_exported_model_to_eval(annotated_model), ) From 5461630b414b6ac0999f51ae9b5097914d480f00 Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Wed, 21 May 2025 14:30:13 -0700 Subject: [PATCH 6/9] up --- .lintrunner.toml | 1 - .../duplicate_dynamic_quant_chain.py | 10 ++++----- extension/llm/export/builder.py | 21 ++++++++++++++----- extension/llm/export/quantizer_lib.py | 2 +- 4 files changed, 21 insertions(+), 13 deletions(-) diff --git a/.lintrunner.toml b/.lintrunner.toml index 25b2409bbf0..653e866ba98 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -397,7 +397,6 @@ exclude_patterns = [ "extension/llm/export/builder.py", "extension/llm/export/quantizer_lib.py", "exir/tests/test_memory_planning.py", - "backends/transforms/duplicate_dynamic_quant_chain.py", "exir/backend/test/demos/test_xnnpack_qnnpack.py", ] diff --git a/backends/transforms/duplicate_dynamic_quant_chain.py b/backends/transforms/duplicate_dynamic_quant_chain.py index 2ca65eec45f..6f75f14c188 100644 --- a/backends/transforms/duplicate_dynamic_quant_chain.py +++ b/backends/transforms/duplicate_dynamic_quant_chain.py @@ -9,14 +9,12 @@ import torch -from torch.ao.quantization.pt2e.utils import ( - _filter_sym_size_users, - _is_valid_annotation, -) - from torch.fx.node import map_arg from torch.fx.passes.infra.pass_base import PassBase, PassResult +from torchao.quantization.pt2e.quantizer import is_valid_annotation +from torchao.quantization.pt2e.utils import _filter_sym_size_users + logger = logging.getLogger(__name__) logger.setLevel(logging.WARNING) @@ -129,7 +127,7 @@ def _maybe_duplicate_dynamic_quantize_chain( dq_node_users = list(dq_node.users.copy()) for user in dq_node_users: annotation = user.meta.get("quantization_annotation", None) - if not _is_valid_annotation(annotation): + if not is_valid_annotation(annotation): return with gm.graph.inserting_after(dq_node): new_node = gm.graph.node_copy(dq_node) diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py index b0da14e965e..b3646d9de2d 100644 --- a/extension/llm/export/builder.py +++ b/extension/llm/export/builder.py @@ -13,7 +13,7 @@ import contextlib import logging from enum import Enum -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple, Union from unittest.mock import patch import torch @@ -35,11 +35,15 @@ from executorch.extension.llm.export.export_passes import RemoveRedundantTransposes from pytorch_tokenizers import get_tokenizer -from torch.ao.quantization.quantizer import Quantizer -from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer +from torch.ao.quantization.quantizer import TorchQuantizer +from torch.ao.quantization.quantizer.composable_quantizer import ( + TorchComposableQuantizer, +) + from torch.export import export_for_training, ExportedProgram from torch.nn.attention import SDPBackend from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e +from torchao.quantization.pt2e.quantizer import ComposableQuantizer, Quantizer from torchao.utils import unwrap_tensor_subclass FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" @@ -350,7 +354,9 @@ def calibrate_template( print(f"{task}: {res}") logging.info("Calibration finish...") - def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManager": + def pt2e_quantize( + self, quantizers: Optional[List[Union[Quantizer, TorchQuantizer]]] + ) -> "LLMEdgeManager": """ Quantize the model via pt2e flow and retrieve LLMEdgeManager including the quantized model. Args: @@ -367,7 +373,12 @@ def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManage with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad(): if self.verbose: logging.info(f"Applied quantizers: {quantizers}") - composed_quantizer = ComposableQuantizer(quantizers) + + if any(isinstance(q, Quantizer) for q in quantizers): + composed_quantizer = ComposableQuantizer(quantizers) + else: + composed_quantizer = TorchComposableQuantizer(quantizers) + assert ( self.pre_autograd_graph_module is not None ), "Please run export() first" diff --git a/extension/llm/export/quantizer_lib.py b/extension/llm/export/quantizer_lib.py index d7b8b3a92b1..b199ab57ccb 100644 --- a/extension/llm/export/quantizer_lib.py +++ b/extension/llm/export/quantizer_lib.py @@ -154,7 +154,7 @@ def get_qnn_quantizer( QnnQuantizer, QuantDtype, ) - from torch.ao.quantization.observer import MinMaxObserver + from torchao.quantization.pt2e import MinMaxObserver except ImportError: raise ImportError( From f32f8a57bd4b897673bb0edfd2c5938afedeafb1 Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Wed, 21 May 2025 14:43:07 -0700 Subject: [PATCH 7/9] up --- extension/llm/export/builder.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py index b3646d9de2d..dd96b32415b 100644 --- a/extension/llm/export/builder.py +++ b/extension/llm/export/builder.py @@ -35,9 +35,9 @@ from executorch.extension.llm.export.export_passes import RemoveRedundantTransposes from pytorch_tokenizers import get_tokenizer -from torch.ao.quantization.quantizer import TorchQuantizer +from torch.ao.quantization.quantizer import Quantizer as TorchQuantizer from torch.ao.quantization.quantizer.composable_quantizer import ( - TorchComposableQuantizer, + ComposableQuantizer as TorchComposableQuantizer, ) from torch.export import export_for_training, ExportedProgram From ebea29301a879b7d83b7412c96ac8db28f36067e Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Wed, 28 May 2025 14:00:55 -0700 Subject: [PATCH 8/9] up --- backends/qualcomm/_passes/qnn_pass_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backends/qualcomm/_passes/qnn_pass_manager.py b/backends/qualcomm/_passes/qnn_pass_manager.py index c2422180d7f..d324f6144a5 100644 --- a/backends/qualcomm/_passes/qnn_pass_manager.py +++ b/backends/qualcomm/_passes/qnn_pass_manager.py @@ -137,7 +137,7 @@ def get_to_edge_transform_passes( node_visitor.q_ops.add(exir_ops.edge.torchao.quantize_affine.default) node_visitor.dq_ops.add(exir_ops.edge.torchao.dequantize_affine.default) - + passes_job = ( passes_job if passes_job is not None else get_capture_program_passes() ) From 675835038daea0dfaa60b7616cf59f361998cb16 Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Wed, 28 May 2025 15:30:37 -0700 Subject: [PATCH 9/9] up --- extension/llm/export/builder.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py index 78e701749ad..d32b44246f6 100644 --- a/extension/llm/export/builder.py +++ b/extension/llm/export/builder.py @@ -35,6 +35,8 @@ from executorch.extension.llm.export.export_passes import RemoveRedundantTransposes from pytorch_tokenizers import get_tokenizer + +# TODO: remove these once pt2e migration from torch.ao to torchao is complete from torch.ao.quantization.quantizer import Quantizer as TorchQuantizer from torch.ao.quantization.quantizer.composable_quantizer import ( ComposableQuantizer as TorchComposableQuantizer, @@ -374,10 +376,14 @@ def pt2e_quantize( if self.verbose: logging.info(f"Applied quantizers: {quantizers}") - if any(isinstance(q, Quantizer) for q in quantizers): + if all(isinstance(q, Quantizer) for q in quantizers): composed_quantizer = ComposableQuantizer(quantizers) - else: + elif all(isinstance(q, TorchQuantizer) for q in quantizers): composed_quantizer = TorchComposableQuantizer(quantizers) + else: + raise ValueError( + "Quantizers must be either Quantizer or TorchQuantizer" + ) assert ( self.pre_autograd_graph_module is not None