diff --git a/.lintrunner.toml b/.lintrunner.toml
index 503b5244ecd..b2603de8323 100644
--- a/.lintrunner.toml
+++ b/.lintrunner.toml
@@ -388,15 +388,12 @@
 exclude_patterns = [
     # backends
     "backends/vulkan/quantizer/**",
     "backends/vulkan/test/**",
-    "backends/qualcomm/quantizer/**",
-    "examples/qualcomm/**",
     "backends/xnnpack/quantizer/**",
     "backends/xnnpack/test/**",
     "exir/tests/test_passes.py",
     "extension/llm/export/builder.py",
     "extension/llm/export/quantizer_lib.py",
     "exir/tests/test_memory_planning.py",
-    "backends/transforms/duplicate_dynamic_quant_chain.py",
     "exir/backend/test/demos/test_xnnpack_qnnpack.py",
 ]
diff --git a/backends/qualcomm/_passes/qnn_pass_manager.py b/backends/qualcomm/_passes/qnn_pass_manager.py
index dcb35a9bc6c..d324f6144a5 100644
--- a/backends/qualcomm/_passes/qnn_pass_manager.py
+++ b/backends/qualcomm/_passes/qnn_pass_manager.py
@@ -135,8 +135,8 @@ def get_to_edge_transform_passes(
         from executorch.backends.qualcomm.builders import node_visitor
         from executorch.exir.dialects._ops import ops as exir_ops

-        node_visitor.q_ops.add(exir_ops.edge.pt2e_quant.quantize_affine.default)
-        node_visitor.dq_ops.add(exir_ops.edge.pt2e_quant.dequantize_affine.default)
+        node_visitor.q_ops.add(exir_ops.edge.torchao.quantize_affine.default)
+        node_visitor.dq_ops.add(exir_ops.edge.torchao.dequantize_affine.default)

         passes_job = (
             passes_job if passes_job is not None else get_capture_program_passes()
diff --git a/backends/qualcomm/builders/node_visitor.py b/backends/qualcomm/builders/node_visitor.py
index d804ba6b6f0..60df273a6f2 100644
--- a/backends/qualcomm/builders/node_visitor.py
+++ b/backends/qualcomm/builders/node_visitor.py
@@ -265,8 +265,8 @@ def get_quant_encoding_conf(
             )
         # TODO: refactor this when target could be correctly detected
         per_block_encoding = {
-            exir_ops.edge.pt2e_quant.quantize_affine.default,
-            exir_ops.edge.pt2e_quant.dequantize_affine.default,
+            exir_ops.edge.torchao.quantize_affine.default,
+            exir_ops.edge.torchao.dequantize_affine.default,
         }
         if quant_attrs[QCOM_ENCODING] in per_block_encoding:
             return self.make_qnn_per_block_config(node, quant_attrs)
diff --git a/backends/qualcomm/partition/utils.py b/backends/qualcomm/partition/utils.py
index 816d1ac1d9b..05bbd1ff970 100644
--- a/backends/qualcomm/partition/utils.py
+++ b/backends/qualcomm/partition/utils.py
@@ -57,7 +57,7 @@ def get_skip_decomp_table() -> List[torch._ops.OperatorBase]:
         torch.ops.aten.upsample_bicubic2d.vec,
         # This request is ignored because it is in a blocklist. Refer to exir/program/_program.py
         torch.ops.aten.unbind.int,
-        torch.ops.pt2e_quant.quantize_affine.default,
-        torch.ops.pt2e_quant.dequantize_affine.default,
+        torch.ops.torchao.quantize_affine.default,
+        torch.ops.torchao.dequantize_affine.default,
     ]
     return do_not_decompose
diff --git a/backends/qualcomm/quantizer/annotators.py b/backends/qualcomm/quantizer/annotators.py
index 7cf661a0e01..730bdaf47d0 100644
--- a/backends/qualcomm/quantizer/annotators.py
+++ b/backends/qualcomm/quantizer/annotators.py
@@ -12,20 +12,17 @@
 from torch._ops import OpOverload
 from torch._subclasses import FakeTensor

-from torch.ao.quantization.fake_quantize import FixedQParamsFakeQuantize
+from torch.fx import Node

-from torch.ao.quantization.observer import FixedQParamsObserver
-from torch.ao.quantization.quantizer import (
+from torchao.quantization.pt2e import FixedQParamsFakeQuantize, FixedQParamsObserver
+from torchao.quantization.pt2e.quantizer import (
+    annotate_input_qspec_map,
+    annotate_output_qspec,
     DerivedQuantizationSpec,
     QuantizationAnnotation,
     QuantizationSpec,
     SharedQuantizationSpec,
 )
-from torch.ao.quantization.quantizer.utils import (
-    _annotate_input_qspec_map,
-    _annotate_output_qspec,
-)
-from torch.fx import Node

 from .qconfig import (
     get_16a16w_qnn_ptq_config,
@@ -643,19 +640,19 @@ def annotate_rms_norm(node: Node, quantization_config: QuantizationConfig) -> No
         return

     # TODO current only support 16a16w
-    _annotate_input_qspec_map(
+    annotate_input_qspec_map(
         node,
         act_node,
         quantization_config.input_activation,
     )
-    _annotate_input_qspec_map(
+    annotate_input_qspec_map(
         node,
         weight_node,
         quantization_config.input_activation,
     )
     nodes_to_mark_annotated = [node]
-    _annotate_output_qspec(node, quantization_config.output_activation)
+    annotate_output_qspec(node, quantization_config.output_activation)
     _mark_nodes_as_annotated(nodes_to_mark_annotated)
@@ -844,25 +841,25 @@ def annotate_group_norm(node: Node, quantization_config: QuantizationConfig) ->
     if _is_annotated([node]):
         return

-    _annotate_input_qspec_map(
+    annotate_input_qspec_map(
         node,
         act_node,
         quantization_config.input_activation,
     )
-    _annotate_input_qspec_map(
+    annotate_input_qspec_map(
         node,
         weight_node,
         quantization_config.weight,
     )
     nodes_to_mark_annotated = [node, weight_node]
     if bias_node:
-        _annotate_input_qspec_map(
+        annotate_input_qspec_map(
             node,
             bias_node,
             quantization_config.bias,
         )
         nodes_to_mark_annotated.append(bias_node)
-    _annotate_output_qspec(node, quantization_config.output_activation)
+    annotate_output_qspec(node, quantization_config.output_activation)
     _mark_nodes_as_annotated(nodes_to_mark_annotated)
@@ -1027,12 +1024,12 @@ def annotate_linear(node: Node, quantization_config: QuantizationConfig) -> None
     if _is_annotated([node]):
         return

-    _annotate_input_qspec_map(
+    annotate_input_qspec_map(
         node,
         act_node,
         quantization_config.input_activation,
     )
-    _annotate_input_qspec_map(
+    annotate_input_qspec_map(
         node,
         weight_node,
         quantization_config.weight,
@@ -1043,9 +1040,9 @@ def annotate_linear(node: Node, quantization_config: QuantizationConfig) -> None
             bias_config = quantization_config.bias(node)
         else:
             bias_config = quantization_config.bias
-        _annotate_input_qspec_map(node, bias_node, bias_config)
+        annotate_input_qspec_map(node, bias_node, bias_config)
         nodes_to_mark_annotated.append(bias_node)
-    _annotate_output_qspec(node, quantization_config.output_activation)
+    annotate_output_qspec(node, quantization_config.output_activation)
     _mark_nodes_as_annotated(nodes_to_mark_annotated)

     # We use get_source_partition in pass, but it is the same source for MultiheadAttention, so we need to change its source_fn_stack.
@@ -1063,14 +1060,14 @@ def annotate_batch_and_instance_norm(
         return

     annotated_args = [act]
-    _annotate_input_qspec_map(
+    annotate_input_qspec_map(
         node,
         act,
         quantization_config.input_activation,
     )
     # QNN requires uint8 instead of int8 in 'weight' config
     if weight is not None:
-        _annotate_input_qspec_map(
+        annotate_input_qspec_map(
             node,
             weight,
             quantization_config.input_activation,
@@ -1078,14 +1075,14 @@ def annotate_batch_and_instance_norm(
         annotated_args.append(weight)

     if bias is not None:
-        _annotate_input_qspec_map(
+        annotate_input_qspec_map(
             node,
             bias,
             quantization_config.bias,
         )
         annotated_args.append(bias)

-    _annotate_output_qspec(node, quantization_config.output_activation)
+    annotate_output_qspec(node, quantization_config.output_activation)
     _mark_nodes_as_annotated([node, *annotated_args])
@@ -1095,7 +1092,7 @@ def annotate_getitem(node: Node, quantization_config: QuantizationConfig) -> Non
         return

     if _is_float_tensor(node):
-        _annotate_output_qspec(node, quantization_config.output_activation)
+        annotate_output_qspec(node, quantization_config.output_activation)
         _mark_nodes_as_annotated([node])
@@ -1111,32 +1108,32 @@ def annotate_layer_norm(node: Node, quantization_config: QuantizationConfig) ->
         return

     input_act_qspec = quantization_config.input_activation

-    _annotate_input_qspec_map(
+    annotate_input_qspec_map(
         node,
         act_node,
         input_act_qspec,
     )
     if input_act_qspec.dtype == torch.int32:
-        _annotate_input_qspec_map(
+        annotate_input_qspec_map(
             node,
             weight_node,
             get_16a16w_qnn_ptq_config().weight,
         )
     else:
-        _annotate_input_qspec_map(
+        annotate_input_qspec_map(
             node,
             weight_node,
             input_act_qspec,
         )
     nodes_to_mark_annotated = [node, weight_node]
     if bias_node:
-        _annotate_input_qspec_map(
+        annotate_input_qspec_map(
             node,
             bias_node,
             quantization_config.bias,
         )
         nodes_to_mark_annotated.append(bias_node)
-    _annotate_output_qspec(node, quantization_config.output_activation)
+    annotate_output_qspec(node, quantization_config.output_activation)
     _mark_nodes_as_annotated(nodes_to_mark_annotated)
diff --git a/backends/qualcomm/quantizer/custom_annotation.py b/backends/qualcomm/quantizer/custom_annotation.py
index bda91609f1c..0e06015ed91 100644
--- a/backends/qualcomm/quantizer/custom_annotation.py
+++ b/backends/qualcomm/quantizer/custom_annotation.py
@@ -17,13 +17,13 @@
     QuantizationConfig,
 )
 from executorch.exir.dialects._ops import ops as exir_ops
-from torch.ao.quantization.observer import FixedQParamsObserver, MinMaxObserver
-from torch.ao.quantization.quantizer import (
+from torch.fx import Node
+from torchao.quantization.pt2e import FixedQParamsObserver, MinMaxObserver
+from torchao.quantization.pt2e.quantizer import (
     QuantizationAnnotation,
     QuantizationSpec,
     SharedQuantizationSpec,
 )
-from torch.fx import Node


 def annotate_mimi_decoder(gm: torch.fx.GraphModule):
diff --git a/backends/qualcomm/quantizer/observers/per_block_param_observer.py b/backends/qualcomm/quantizer/observers/per_block_param_observer.py
index e60f15c6d9c..802d5706d89 100644
--- a/backends/qualcomm/quantizer/observers/per_block_param_observer.py
+++ b/backends/qualcomm/quantizer/observers/per_block_param_observer.py
@@ -7,8 +7,8 @@
 from typing import Tuple

 import torch
-from torch.ao.quantization.observer import MappingType, PerBlock
-from torch.ao.quantization.pt2e._affine_quantization import (
+from torchao.quantization.pt2e import MappingType, PerBlock
+from torchao.quantization.pt2e._affine_quantization import (
     _get_reduction_params,
     AffineQuantizedMinMaxObserver,
     choose_qparams_affine_with_min_max,
diff --git a/backends/qualcomm/quantizer/observers/per_channel_param_observer.py b/backends/qualcomm/quantizer/observers/per_channel_param_observer.py
index 3c04e620308..9f89f6b0e69 100644
--- a/backends/qualcomm/quantizer/observers/per_channel_param_observer.py
+++ b/backends/qualcomm/quantizer/observers/per_channel_param_observer.py
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.

 import torch
-from torch.ao.quantization.observer import UniformQuantizationObserverBase
+from torchao.quantization.pt2e import UniformQuantizationObserverBase


 # TODO move to torch/ao/quantization/observer.py.
diff --git a/backends/qualcomm/quantizer/qconfig.py b/backends/qualcomm/quantizer/qconfig.py
index 67968363eb6..e2a9cd83567 100644
--- a/backends/qualcomm/quantizer/qconfig.py
+++ b/backends/qualcomm/quantizer/qconfig.py
@@ -7,18 +7,19 @@
     PerBlockParamObserver,
 )
 from torch import Tensor
-from torch.ao.quantization.fake_quantize import (
+from torch.fx import Node
+from torchao.quantization.pt2e import (
     FakeQuantize,
     FusedMovingAvgObsFakeQuantize,
-)
-from torch.ao.quantization.observer import (
     MinMaxObserver,
     MovingAverageMinMaxObserver,
     MovingAveragePerChannelMinMaxObserver,
     PerChannelMinMaxObserver,
 )
-from torch.ao.quantization.quantizer import DerivedQuantizationSpec, QuantizationSpec
-from torch.fx import Node
+from torchao.quantization.pt2e.quantizer import (
+    DerivedQuantizationSpec,
+    QuantizationSpec,
+)


 @dataclass(eq=True)
diff --git a/backends/qualcomm/quantizer/quantizer.py b/backends/qualcomm/quantizer/quantizer.py
index 8e65607dd84..9a149e7db87 100644
--- a/backends/qualcomm/quantizer/quantizer.py
+++ b/backends/qualcomm/quantizer/quantizer.py
@@ -12,8 +12,9 @@
 from executorch.backends.qualcomm._passes.qnn_pass_manager import QnnPassManager

 from torch._ops import OpOverload
-from torch.ao.quantization.quantizer import Quantizer
 from torch.fx import GraphModule
+from torchao.quantization.pt2e import UniformQuantizationObserverBase
+from torchao.quantization.pt2e.quantizer import Quantizer

 from .annotators import OP_ANNOTATOR
@@ -130,9 +131,7 @@ class ModuleQConfig:
     is_qat: bool = False
     is_conv_per_channel: bool = False
     is_linear_per_channel: bool = False
-    act_observer: Optional[
-        torch.ao.quantization.observer.UniformQuantizationObserverBase
-    ] = None
+    act_observer: Optional[UniformQuantizationObserverBase] = None

     def __post_init__(self):
         if (self.quant_dtype, self.is_qat) not in QUANT_CONFIG_DICT:
diff --git a/backends/qualcomm/tests/utils.py b/backends/qualcomm/tests/utils.py
index e89328904a1..173590b2a63 100644
--- a/backends/qualcomm/tests/utils.py
+++ b/backends/qualcomm/tests/utils.py
@@ -14,6 +14,7 @@
 import numpy as np
 import torch
+import torchao
 from executorch import exir
 from executorch.backends.qualcomm.builders.node_visitor import dq_ops
 from executorch.backends.qualcomm.qnn_preprocess import QnnBackend
@@ -537,8 +538,8 @@ def get_qdq_module(
             torch.ops.quantized_decomposed.dequantize_per_tensor.default,
             torch.ops.quantized_decomposed.quantize_per_channel.default,
             torch.ops.quantized_decomposed.dequantize_per_channel.default,
-            torch.ops.pt2e_quant.quantize_affine.default,
-            torch.ops.pt2e_quant.dequantize_affine.default,
+            torch.ops.torchao.quantize_affine.default,
+            torch.ops.torchao.dequantize_affine.default,
         }
         if not bypass_check:
             self.assertTrue(nodes.intersection(q_and_dq))
@@ -569,7 +570,7 @@ def get_prepared_qat_module(
         quantizer.set_submodule_qconfig_list(submodule_qconfig_list)

         prepared = prepare_qat_pt2e(m, quantizer)
-        return torch.ao.quantization.move_exported_model_to_train(prepared)
+        return torchao.quantization.pt2e.move_exported_model_to_train(prepared)

     def get_converted_sgd_trained_module(
         self,
diff --git a/backends/transforms/duplicate_dynamic_quant_chain.py b/backends/transforms/duplicate_dynamic_quant_chain.py
index 2ca65eec45f..6f75f14c188 100644
--- a/backends/transforms/duplicate_dynamic_quant_chain.py
+++ b/backends/transforms/duplicate_dynamic_quant_chain.py
@@ -9,14 +9,12 @@

 import torch

-from torch.ao.quantization.pt2e.utils import (
-    _filter_sym_size_users,
-    _is_valid_annotation,
-)
-
 from torch.fx.node import map_arg
 from torch.fx.passes.infra.pass_base import PassBase, PassResult

+from torchao.quantization.pt2e.quantizer import is_valid_annotation
+from torchao.quantization.pt2e.utils import _filter_sym_size_users
+
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.WARNING)
@@ -129,7 +127,7 @@ def _maybe_duplicate_dynamic_quantize_chain(
     dq_node_users = list(dq_node.users.copy())
     for user in dq_node_users:
         annotation = user.meta.get("quantization_annotation", None)
-        if not _is_valid_annotation(annotation):
+        if not is_valid_annotation(annotation):
             return
         with gm.graph.inserting_after(dq_node):
             new_node = gm.graph.node_copy(dq_node)
diff --git a/examples/qualcomm/oss_scripts/llama/llama.py b/examples/qualcomm/oss_scripts/llama/llama.py
index e4b3c579ae2..c97aadc79a9 100755
--- a/examples/qualcomm/oss_scripts/llama/llama.py
+++ b/examples/qualcomm/oss_scripts/llama/llama.py
@@ -81,7 +81,7 @@
 from pytorch_tokenizers import get_tokenizer, TiktokenTokenizer
 from pytorch_tokenizers.llama2c import Llama2cTokenizer as SentencePieceTokenizer

-from torch.ao.quantization.observer import MinMaxObserver
+from torchao.quantization.pt2e import MinMaxObserver
 from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

 sys.setrecursionlimit(4096)
diff --git a/examples/qualcomm/oss_scripts/moshi/mimi.py b/examples/qualcomm/oss_scripts/moshi/mimi.py
index 6b59a71ae64..70e339a32d6 100644
--- a/examples/qualcomm/oss_scripts/moshi/mimi.py
+++ b/examples/qualcomm/oss_scripts/moshi/mimi.py
@@ -37,7 +37,7 @@

 from huggingface_hub import hf_hub_download
 from moshi.models import loaders
-from torch.ao.quantization.observer import MinMaxObserver
+from torchao.quantization.pt2e import MinMaxObserver


 def seed_all(seed):
diff --git a/examples/qualcomm/utils.py b/examples/qualcomm/utils.py
index 856195d9211..526e376d148 100755
--- a/examples/qualcomm/utils.py
+++ b/examples/qualcomm/utils.py
@@ -19,6 +19,7 @@

 import numpy as np
 import torch
+import torchao
 from executorch.backends.qualcomm.quantizer.quantizer import (
     ModuleQConfig,
     QnnQuantizer,
@@ -33,7 +34,7 @@
 )
 from executorch.exir.capture._config import ExecutorchBackendConfig
 from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass
-from torch.ao.quantization.observer import MovingAverageMinMaxObserver
+from torchao.quantization.pt2e import MovingAverageMinMaxObserver
 from torchao.quantization.pt2e.quantize_pt2e import (
     convert_pt2e,
     prepare_pt2e,
@@ -229,7 +230,7 @@ def ptq_calibrate(captured_model, quantizer, dataset):

 def qat_train(ori_model, captured_model, quantizer, dataset):
     data, targets = dataset
-    annotated_model = torch.ao.quantization.move_exported_model_to_train(
+    annotated_model = torchao.quantization.pt2e.move_exported_model_to_train(
         prepare_qat_pt2e(captured_model, quantizer)
     )
     optimizer = torch.optim.SGD(annotated_model.parameters(), lr=0.00001)
@@ -238,7 +239,9 @@ def qat_train(ori_model, captured_model, quantizer, dataset):
         print(f"Epoch {i}")
         if i > 3:
             # Freeze quantizer parameters
-            annotated_model.apply(torch.ao.quantization.disable_observer)
+            annotated_model.apply(
+                torchao.quantization.pt2e.fake_quantize.disable_observer
+            )
         if i > 2:
             # Freeze batch norm mean and variance estimates
             annotated_model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
@@ -250,7 +253,7 @@ def qat_train(ori_model, captured_model, quantizer, dataset):
         optimizer.step()

     return convert_pt2e(
-        torch.ao.quantization.move_exported_model_to_eval(annotated_model),
+        torchao.quantization.pt2e.move_exported_model_to_eval(annotated_model),
     )
diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py
index feae9f45861..d32b44246f6 100644
--- a/extension/llm/export/builder.py
+++ b/extension/llm/export/builder.py
@@ -13,7 +13,7 @@
 import contextlib
 import logging
 from enum import Enum
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 from unittest.mock import patch

 import torch
@@ -35,11 +35,17 @@
 from executorch.extension.llm.export.export_passes import RemoveRedundantTransposes
 from pytorch_tokenizers import get_tokenizer

-from torch.ao.quantization.quantizer import Quantizer
-from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer
+
+# TODO: remove these once pt2e migration from torch.ao to torchao is complete
+from torch.ao.quantization.quantizer import Quantizer as TorchQuantizer
+from torch.ao.quantization.quantizer.composable_quantizer import (
+    ComposableQuantizer as TorchComposableQuantizer,
+)
+
 from torch.export import export_for_training, ExportedProgram
 from torch.nn.attention import SDPBackend
 from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
+from torchao.quantization.pt2e.quantizer import ComposableQuantizer, Quantizer
 from torchao.utils import unwrap_tensor_subclass

 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
@@ -350,7 +356,9 @@ def calibrate_template(
             print(f"{task}: {res}")
         logging.info("Calibration finish...")

-    def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManager":
+    def pt2e_quantize(
+        self, quantizers: Optional[List[Union[Quantizer, TorchQuantizer]]]
+    ) -> "LLMEdgeManager":
         """
         Quantize the model via pt2e flow and retrieve LLMEdgeManager including the
         quantized model.

         Args:
@@ -367,7 +375,16 @@ def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManage
         with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
             if self.verbose:
                 logging.info(f"Applied quantizers: {quantizers}")
-            composed_quantizer = ComposableQuantizer(quantizers)
+
+            if all(isinstance(q, Quantizer) for q in quantizers):
+                composed_quantizer = ComposableQuantizer(quantizers)
+            elif all(isinstance(q, TorchQuantizer) for q in quantizers):
+                composed_quantizer = TorchComposableQuantizer(quantizers)
+            else:
+                raise ValueError(
+                    "Quantizers must be either Quantizer or TorchQuantizer"
+                )
+
             assert (
                 self.pre_autograd_graph_module is not None
             ), "Please run export() first"
diff --git a/extension/llm/export/quantizer_lib.py b/extension/llm/export/quantizer_lib.py
index d7b8b3a92b1..b199ab57ccb 100644
--- a/extension/llm/export/quantizer_lib.py
+++ b/extension/llm/export/quantizer_lib.py
@@ -154,7 +154,7 @@ def get_qnn_quantizer(
             QnnQuantizer,
             QuantDtype,
         )
-        from torch.ao.quantization.observer import MinMaxObserver
+        from torchao.quantization.pt2e import MinMaxObserver
     except ImportError:
         raise ImportError(