From 8ff97cc35ebfaf60635c9bd03be9c2bf2f94ea06 Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Mon, 9 Jun 2025 09:43:24 -0700 Subject: [PATCH] Migrate xnnpack/vulkan/boltnn pt2e from torch.ao to torchao (#11363) Summary: X-link: https://github.com/pytorch/ao/pull/2302 Pull Request resolved: https://github.com/pytorch/executorch/pull/11363 Reviewed By: jerryzh168, billmguo Differential Revision: D75492104 --- .lintrunner.toml | 8 +- backends/vulkan/quantizer/vulkan_quantizer.py | 9 +- backends/vulkan/test/test_vulkan_delegate.py | 4 +- backends/vulkan/test/test_vulkan_passes.py | 2 +- .../partition/config/quant_affine_configs.py | 21 +-- .../xnnpack/quantizer/xnnpack_quantizer.py | 25 ++-- .../quantizer/xnnpack_quantizer_utils.py | 137 ++++-------------- .../test/quantizer/test_pt2e_quantization.py | 82 +++++++---- .../test/quantizer/test_representation.py | 2 +- .../test/quantizer/test_xnnpack_quantizer.py | 9 +- backends/xnnpack/test/tester/tester.py | 2 +- docs/source/backends-xnnpack.md | 3 +- exir/tests/test_passes.py | 7 +- extension/llm/export/builder.py | 22 +-- extension/llm/export/quantizer_lib.py | 4 +- 15 files changed, 120 insertions(+), 217 deletions(-) diff --git a/.lintrunner.toml b/.lintrunner.toml index b2603de8323..8912e65d66d 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -386,15 +386,9 @@ exclude_patterns = [ "third-party/**", # TODO: remove exceptions as we migrate # backends - "backends/vulkan/quantizer/**", - "backends/vulkan/test/**", - "backends/xnnpack/quantizer/**", - "backends/xnnpack/test/**", - "exir/tests/test_passes.py", - "extension/llm/export/builder.py", - "extension/llm/export/quantizer_lib.py", "exir/tests/test_memory_planning.py", "exir/backend/test/demos/test_xnnpack_qnnpack.py", + "backends/xnnpack/test/test_xnnpack_utils.py", ] command = [ diff --git a/backends/vulkan/quantizer/vulkan_quantizer.py b/backends/vulkan/quantizer/vulkan_quantizer.py index b2f1a658040..a82c2091cf6 100644 --- a/backends/vulkan/quantizer/vulkan_quantizer.py +++ b/backends/vulkan/quantizer/vulkan_quantizer.py @@ -16,11 +16,14 @@ _convert_scalars_to_attrs, OP_TO_ANNOTATOR, propagate_annotation, - QuantizationConfig, ) -from torch.ao.quantization.observer import PerChannelMinMaxObserver -from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer from torch.fx import Node +from torchao.quantization.pt2e import PerChannelMinMaxObserver +from torchao.quantization.pt2e.quantizer import ( + QuantizationConfig, + QuantizationSpec, + Quantizer, +) __all__ = [ diff --git a/backends/vulkan/test/test_vulkan_delegate.py b/backends/vulkan/test/test_vulkan_delegate.py index 447e5d039f4..0a81c66a0ad 100644 --- a/backends/vulkan/test/test_vulkan_delegate.py +++ b/backends/vulkan/test/test_vulkan_delegate.py @@ -23,12 +23,12 @@ EdgeProgramManager, ExecutorchProgramManager, ) - -from torch.ao.quantization.quantizer import Quantizer from torch.export import Dim, export, export_for_training, ExportedProgram from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e +from torchao.quantization.pt2e.quantizer import Quantizer + ctypes.CDLL("libvulkan.so.1") diff --git a/backends/vulkan/test/test_vulkan_passes.py b/backends/vulkan/test/test_vulkan_passes.py index 66156e88155..ff9e2d85a96 100644 --- a/backends/vulkan/test/test_vulkan_passes.py +++ b/backends/vulkan/test/test_vulkan_passes.py @@ -16,9 +16,9 @@ from executorch.exir.backend.canonical_partitioners.config_partitioner import ( format_target_name, ) -from torch.ao.quantization.quantizer import 
Quantizer from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e +from torchao.quantization.pt2e.quantizer import Quantizer ################### ## Common Models ## diff --git a/backends/xnnpack/partition/config/quant_affine_configs.py b/backends/xnnpack/partition/config/quant_affine_configs.py index 046402800a3..f7f59d6e096 100644 --- a/backends/xnnpack/partition/config/quant_affine_configs.py +++ b/backends/xnnpack/partition/config/quant_affine_configs.py @@ -33,33 +33,24 @@ class QuantizeAffineConfig(QDQAffineConfigs): target_name = "quantize_affine.default" def get_original_aten(self) -> Optional[torch._ops.OpOverload]: - try: - import torchao.quantization.quant_primitives # noqa + import torchao.quantization.quant_primitives # noqa - return torch.ops.torchao.quantize_affine.default - except: - return None + return torch.ops.torchao.quantize_affine.default class DeQuantizeAffineConfig(QDQAffineConfigs): target_name = "dequantize_affine.default" def get_original_aten(self) -> Optional[torch._ops.OpOverload]: - try: - import torchao.quantization.quant_primitives # noqa + import torchao.quantization.quant_primitives # noqa - return torch.ops.torchao.dequantize_affine.default - except: - return None + return torch.ops.torchao.dequantize_affine.default class ChooseQParamsAffineConfig(QDQAffineConfigs): target_name = "choose_qparams_affine.default" def get_original_aten(self) -> Optional[torch._ops.OpOverload]: - try: - import torchao.quantization.quant_primitives # noqa + import torchao.quantization.quant_primitives # noqa - return torch.ops.torchao.choose_qparams_affine.default - except: - return None + return torch.ops.torchao.choose_qparams_affine.default diff --git a/backends/xnnpack/quantizer/xnnpack_quantizer.py b/backends/xnnpack/quantizer/xnnpack_quantizer.py index 229b75f0ed9..130eda03f88 100644 --- a/backends/xnnpack/quantizer/xnnpack_quantizer.py +++ b/backends/xnnpack/quantizer/xnnpack_quantizer.py @@ -12,16 +12,11 @@ from executorch.backends.xnnpack.quantizer.xnnpack_quantizer_utils import ( _convert_scalars_to_attrs, OP_TO_ANNOTATOR, - OperatorConfig, - OperatorPatternType, propagate_annotation, - QuantizationConfig, ) -from torch.ao.quantization.fake_quantize import ( +from torchao.quantization.pt2e import ( FakeQuantize, FusedMovingAvgObsFakeQuantize, -) -from torch.ao.quantization.observer import ( HistogramObserver, MinMaxObserver, MovingAverageMinMaxObserver, @@ -29,13 +24,19 @@ PerChannelMinMaxObserver, PlaceholderObserver, ) -from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer -from torch.ao.quantization.quantizer.utils import _get_module_name_filter +from torchao.quantization.pt2e.quantizer import ( + get_module_name_filter, + OperatorConfig, + OperatorPatternType, + QuantizationConfig, + QuantizationSpec, + Quantizer, +) if TYPE_CHECKING: - from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor from torch.fx import Node + from torchao.quantization.pt2e import ObserverOrFakeQuantizeConstructor __all__ = [ @@ -140,7 +141,7 @@ def get_symmetric_quantization_config( weight_qscheme = ( torch.per_channel_symmetric if is_per_channel else torch.per_tensor_symmetric ) - weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = ( + weight_observer_or_fake_quant_ctr: ObserverOrFakeQuantizeConstructor = ( MinMaxObserver ) if is_qat: @@ -228,7 +229,7 @@ def _get_not_module_type_or_name_filter( tp_list: list[Callable], module_name_list: list[str] ) -> Callable[[Node], bool]: module_type_filters = 
[_get_module_type_filter(tp) for tp in tp_list] - module_name_list_filters = [_get_module_name_filter(m) for m in module_name_list] + module_name_list_filters = [get_module_name_filter(m) for m in module_name_list] def not_module_type_or_name_filter(n: Node) -> bool: return not any(f(n) for f in module_type_filters + module_name_list_filters) @@ -421,7 +422,7 @@ def _annotate_for_quantization_config( module_name_list = list(self.module_name_config.keys()) for module_name, config in self.module_name_config.items(): self._annotate_all_patterns( - model, config, _get_module_name_filter(module_name) + model, config, get_module_name_filter(module_name) ) tp_list = list(self.module_type_config.keys()) diff --git a/backends/xnnpack/quantizer/xnnpack_quantizer_utils.py b/backends/xnnpack/quantizer/xnnpack_quantizer_utils.py index 4b961bef81d..0dcfb4484ed 100644 --- a/backends/xnnpack/quantizer/xnnpack_quantizer_utils.py +++ b/backends/xnnpack/quantizer/xnnpack_quantizer_utils.py @@ -1,39 +1,43 @@ # mypy: allow-untyped-defs import itertools -import typing -from dataclasses import dataclass -from typing import Callable, NamedTuple, Optional +from typing import Callable, Optional import torch import torch.nn.functional as F from executorch.backends.xnnpack.utils.utils import is_depthwise_conv from torch._subclasses import FakeTensor -from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix -from torch.ao.quantization.pt2e.export_utils import _WrapperModule -from torch.ao.quantization.pt2e.utils import ( - _get_aten_graph_module_for_pattern, - _is_conv_node, - _is_conv_transpose_node, +from torch.fx import Node +from torch.fx.passes.utils.matcher_with_name_node_map_utils import ( + SubgraphMatcherWithNameNodeMap, ) -from torch.ao.quantization.quantizer import ( +from torchao.quantization.pt2e import WrapperModule +from torchao.quantization.pt2e.graph_utils import get_source_partitions +from torchao.quantization.pt2e.quantizer import ( + annotate_input_qspec_map, + annotate_output_qspec, + get_bias_qspec, + get_input_act_qspec, + get_output_act_qspec, + get_weight_qspec, + OperatorConfig, + OperatorPatternType, QuantizationAnnotation, + QuantizationConfig, QuantizationSpec, SharedQuantizationSpec, ) -from torch.ao.quantization.quantizer.utils import ( - _annotate_input_qspec_map, - _annotate_output_qspec, -) -from torch.fx import Node -from torch.fx.passes.utils.matcher_with_name_node_map_utils import ( - SubgraphMatcherWithNameNodeMap, +from torchao.quantization.pt2e.utils import ( + _get_aten_graph_module_for_pattern, + _is_conv_node, + _is_conv_transpose_node, + get_new_attr_name_with_prefix, ) -from torch.fx.passes.utils.source_matcher_utils import get_source_partitions __all__ = [ "OperatorConfig", "OperatorPatternType", "QuantizationConfig", + "QuantizationSpec", "get_input_act_qspec", "get_output_act_qspec", "get_weight_qspec", @@ -43,23 +47,6 @@ ] -# In the absence of better name, just winging it with QuantizationConfig -@dataclass(eq=True, frozen=True) -class QuantizationConfig: - input_activation: Optional[QuantizationSpec] - output_activation: Optional[QuantizationSpec] - weight: Optional[QuantizationSpec] - bias: Optional[QuantizationSpec] - # TODO: remove, since we can use observer_or_fake_quant_ctr to express this - is_qat: bool = False - - -# Use Annotated because list[Callable].__module__ is read-only. 
-OperatorPatternType = typing.Annotated[list[Callable], None] -OperatorPatternType.__module__ = ( - "executorch.backends.xnnpack.quantizer.xnnpack_quantizer_utils" -) - AnnotatorType = Callable[ [ torch.fx.GraphModule, @@ -78,19 +65,6 @@ def decorator(annotator: AnnotatorType) -> None: return decorator -class OperatorConfig(NamedTuple): - # fix List[str] with List[List[Union[nn.Module, FunctionType, BuiltinFunctionType]]] - # Basically we are mapping a quantization config to some list of patterns. - # a pattern is defined as a list of nn module, function or builtin function names - # e.g. [nn.Conv2d, torch.relu, torch.add] - # We have not resolved whether fusion can be considered internal details of the - # quantizer hence it does not need communication to user. - # Note this pattern is not really informative since it does not really - # tell us the graph structure resulting from the list of ops. - config: QuantizationConfig - operators: list[OperatorPatternType] - - def is_relu_node(node: Node) -> bool: """ Check if a given node is a relu node @@ -124,63 +98,6 @@ def _mark_nodes_as_annotated(nodes: list[Node]): node.meta["quantization_annotation"]._annotated = True -def get_input_act_qspec(quantization_config: Optional[QuantizationConfig]): - if quantization_config is None: - return None - if quantization_config.input_activation is None: - return None - quantization_spec: QuantizationSpec = quantization_config.input_activation - assert quantization_spec.qscheme in [ - torch.per_tensor_affine, - torch.per_tensor_symmetric, - ] - return quantization_spec - - -def get_output_act_qspec(quantization_config: Optional[QuantizationConfig]): - if quantization_config is None: - return None - if quantization_config.output_activation is None: - return None - quantization_spec: QuantizationSpec = quantization_config.output_activation - assert quantization_spec.qscheme in [ - torch.per_tensor_affine, - torch.per_tensor_symmetric, - ] - return quantization_spec - - -def get_weight_qspec(quantization_config: Optional[QuantizationConfig]): - if quantization_config is None: - return None - assert quantization_config is not None - if quantization_config.weight is None: - return None - quantization_spec: QuantizationSpec = quantization_config.weight - if quantization_spec.qscheme not in [ - torch.per_tensor_symmetric, - torch.per_channel_symmetric, - None, - ]: - raise ValueError( - f"Unsupported quantization_spec {quantization_spec} for weight" - ) - return quantization_spec - - -def get_bias_qspec(quantization_config: Optional[QuantizationConfig]): - if quantization_config is None: - return None - assert quantization_config is not None - if quantization_config.bias is None: - return None - quantization_spec: QuantizationSpec = quantization_config.bias - assert ( - quantization_spec.dtype == torch.float - ), "Only float dtype for bias is supported for bias right now" - return quantization_spec - - @register_annotator("linear") def _annotate_linear( gm: torch.fx.GraphModule, @@ -204,25 +121,25 @@ def _annotate_linear( bias_node = node.args[2] if _is_annotated([node]) is False: # type: ignore[list-item] - _annotate_input_qspec_map( + annotate_input_qspec_map( node, act_node, input_act_qspec, ) - _annotate_input_qspec_map( + annotate_input_qspec_map( node, weight_node, weight_qspec, ) nodes_to_mark_annotated = [node, weight_node] if bias_node: - _annotate_input_qspec_map( + annotate_input_qspec_map( node, bias_node, bias_qspec, ) nodes_to_mark_annotated.append(bias_node) - _annotate_output_qspec(node, 
output_act_qspec) + annotate_output_qspec(node, output_act_qspec) _mark_nodes_as_annotated(nodes_to_mark_annotated) annotated_partitions.append(nodes_to_mark_annotated) @@ -572,7 +489,7 @@ def _conv_bn(x, conv_weight, conv_bias, bn_weight, bn_bias, bn_rm, bn_rv): "output": output, } - return _WrapperModule(_conv_bn) + return WrapperModule(_conv_bn) # Needed for matching, otherwise the matches gets filtered out due to unused # nodes returned by batch norm diff --git a/backends/xnnpack/test/quantizer/test_pt2e_quantization.py b/backends/xnnpack/test/quantizer/test_pt2e_quantization.py index 5f8682f621c..5053458613e 100644 --- a/backends/xnnpack/test/quantizer/test_pt2e_quantization.py +++ b/backends/xnnpack/test/quantizer/test_pt2e_quantization.py @@ -14,17 +14,7 @@ get_symmetric_quantization_config, XNNPACKQuantizer, ) -from torch.ao.quantization import ( - compare_results, - CUSTOM_KEY, - default_per_channel_symmetric_qnnpack_qconfig, - extract_results_from_loggers, - generate_numeric_debug_handle, - NUMERIC_DEBUG_HANDLE_KEY, - observer, - prepare_for_propagation_comparison, -) -from torch.ao.quantization.pt2e.graph_utils import bfs_trace_with_node_process +from torch.ao.quantization import default_per_channel_symmetric_qnnpack_qconfig from torch.ao.quantization.qconfig import ( float_qparams_weight_only_qconfig, per_channel_weight_observer_range_neg_127_to_127, @@ -32,28 +22,52 @@ weight_observer_range_neg_127_to_127, ) from torch.ao.quantization.qconfig_mapping import QConfigMapping -from torch.ao.quantization.quantizer import Quantizer -from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer -from torch.ao.quantization.quantizer.embedding_quantizer import EmbeddingQuantizer from torch.export import export_for_training from torch.testing._internal.common_quantization import ( NodeSpec as ns, - PT2EQuantizationTestCase, TestHelperModules, ) + from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, TemporaryFileName, TestCase, ) +from torchao.quantization.pt2e import ( + allow_exported_model_train_eval, + compare_results, + CUSTOM_KEY, + extract_results_from_loggers, + generate_numeric_debug_handle, + NUMERIC_DEBUG_HANDLE_KEY, + prepare_for_propagation_comparison, +) + +from torchao.quantization.pt2e.graph_utils import bfs_trace_with_node_process from torchao.quantization.pt2e.quantize_pt2e import ( convert_pt2e, prepare_pt2e, prepare_qat_pt2e, ) +from torchao.quantization.pt2e.quantizer import ComposableQuantizer, Quantizer +from torchao.quantization.pt2e.quantizer.embedding_quantizer import EmbeddingQuantizer +from torchao.testing.pt2e.utils import PT2EQuantizationTestCase class TestQuantizePT2E(PT2EQuantizationTestCase): + def _quantize(self, m, quantizer, example_inputs, is_qat: bool = False): + # resetting dynamo cache + torch._dynamo.reset() + + m = export_for_training(m, example_inputs, strict=True).module() + if is_qat: + m = prepare_qat_pt2e(m, quantizer) + else: + m = prepare_pt2e(m, quantizer) + m(*example_inputs) + m = convert_pt2e(m) + return m + def _get_pt2e_quantized_linear( self, is_per_channel: bool = False ) -> torch.fx.GraphModule: @@ -198,13 +212,15 @@ def test_composable_quantizer_linear_conv(self) -> None: torch.ops.quantized_decomposed.quantize_per_channel.default: 0, torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, } - act_affine_quant_obs = observer.PlaceholderObserver.with_args( - dtype=torch.qint8, - qscheme=torch.per_tensor_affine, - quant_min=-128, - quant_max=127, - eps=2**-12, - 
is_dynamic=True, + act_affine_quant_obs = ( + torch.ao.quantization.observer.PlaceholderObserver.with_args( + dtype=torch.qint8, + qscheme=torch.per_tensor_affine, + quant_min=-128, + quant_max=127, + eps=2**-12, + is_dynamic=True, + ) ) dynamic_qconfig = QConfig( activation=act_affine_quant_obs, @@ -287,13 +303,15 @@ def test_embedding_conv_linear_quantization(self) -> None: [embedding_quantizer, dynamic_quantizer, static_quantizer] ) - act_affine_quant_obs = observer.PlaceholderObserver.with_args( - dtype=torch.qint8, - qscheme=torch.per_tensor_affine, - quant_min=-128, - quant_max=127, - eps=2**-12, - is_dynamic=True, + act_affine_quant_obs = ( + torch.ao.quantization.observer.PlaceholderObserver.with_args( + dtype=torch.qint8, + qscheme=torch.per_tensor_affine, + quant_min=-128, + quant_max=127, + eps=2**-12, + is_dynamic=True, + ) ) dynamic_qconfig = QConfig( activation=act_affine_quant_obs, @@ -404,7 +422,7 @@ def _assert_ops_are_correct(m: torch.fx.GraphModule, train: bool) -> None: m.train() # After wrapping: does not error and swaps the ops accordingly - torch.ao.quantization.allow_exported_model_train_eval(m) # pyre-ignore[6] + allow_exported_model_train_eval(m) # pyre-ignore[6] m.eval() _assert_ops_are_correct(m, train=False) # pyre-ignore[6] m.train() @@ -419,7 +437,7 @@ def _assert_ops_are_correct(m: torch.fx.GraphModule, train: bool) -> None: m.train() # After prepare and after wrapping: does not error and swaps the ops accordingly - torch.ao.quantization.allow_exported_model_train_eval(m) + allow_exported_model_train_eval(m) m.eval() _assert_ops_are_correct(m, train=False) m.train() @@ -433,7 +451,7 @@ def _assert_ops_are_correct(m: torch.fx.GraphModule, train: bool) -> None: m.train() # After convert and after wrapping: does not error and swaps the ops accordingly - torch.ao.quantization.allow_exported_model_train_eval(m) + allow_exported_model_train_eval(m) m.eval() _assert_ops_are_correct(m, train=False) m.train() diff --git a/backends/xnnpack/test/quantizer/test_representation.py b/backends/xnnpack/test/quantizer/test_representation.py index de9a9cb14ea..817f7f9e368 100644 --- a/backends/xnnpack/test/quantizer/test_representation.py +++ b/backends/xnnpack/test/quantizer/test_representation.py @@ -8,7 +8,6 @@ XNNPACKQuantizer, ) from torch._higher_order_ops.out_dtype import out_dtype # noqa: F401 -from torch.ao.quantization.quantizer import Quantizer from torch.export import export_for_training from torch.testing._internal.common_quantization import ( NodeSpec as ns, @@ -17,6 +16,7 @@ TestHelperModules, ) from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e +from torchao.quantization.pt2e.quantizer import Quantizer @skipIfNoQNNPACK diff --git a/backends/xnnpack/test/quantizer/test_xnnpack_quantizer.py b/backends/xnnpack/test/quantizer/test_xnnpack_quantizer.py index ae787059860..0a317ad8822 100644 --- a/backends/xnnpack/test/quantizer/test_xnnpack_quantizer.py +++ b/backends/xnnpack/test/quantizer/test_xnnpack_quantizer.py @@ -12,7 +12,6 @@ from torch.ao.quantization import ( default_dynamic_fake_quant, default_dynamic_qconfig, - observer, QConfig, QConfigMapping, ) @@ -31,13 +30,13 @@ from torch.export import export_for_training from torch.testing._internal.common_quantization import ( NodeSpec as ns, - PT2EQuantizationTestCase, skip_if_no_torchvision, skipIfNoQNNPACK, TestHelperModules, ) from torch.testing._internal.common_quantized import override_quantized_engine from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e 
+from torchao.testing.pt2e.utils import PT2EQuantizationTestCase @skipIfNoQNNPACK @@ -575,7 +574,7 @@ def test_dynamic_linear(self): torch.ops.quantized_decomposed.quantize_per_channel.default: 0, torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, } - act_affine_quant_obs = observer.PlaceholderObserver.with_args( + act_affine_quant_obs = torch.ao.quantization.PlaceholderObserver.with_args( dtype=torch.qint8, qscheme=torch.per_tensor_affine, quant_min=-128, @@ -621,7 +620,7 @@ def test_dynamic_linear_int4_weight(self): torch.ops.quantized_decomposed.quantize_per_channel.default: 0, torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, } - act_affine_quant_obs = observer.PlaceholderObserver.with_args( + act_affine_quant_obs = torch.ao.quantization.PlaceholderObserver.with_args( dtype=torch.qint8, qscheme=torch.per_tensor_affine, quant_min=-128, @@ -718,7 +717,7 @@ def test_dynamic_linear_with_conv(self): torch.ops.quantized_decomposed.quantize_per_tensor.default: 0, torch.ops.quantized_decomposed.dequantize_per_tensor.default: 0, } - act_affine_quant_obs = observer.PlaceholderObserver.with_args( + act_affine_quant_obs = torch.ao.quantization.PlaceholderObserver.with_args( dtype=torch.qint8, qscheme=torch.per_tensor_affine, quant_min=-128, diff --git a/backends/xnnpack/test/tester/tester.py b/backends/xnnpack/test/tester/tester.py index 381bd1efa4d..dad0d5ad0e0 100644 --- a/backends/xnnpack/test/tester/tester.py +++ b/backends/xnnpack/test/tester/tester.py @@ -56,7 +56,6 @@ ) from executorch.exir.program._program import _transform from torch._export.pass_base import PassType -from torch.ao.quantization.quantizer.quantizer import Quantizer from torch.export import export, ExportedProgram from torch.testing import FileCheck from torch.utils._pytree import tree_flatten @@ -65,6 +64,7 @@ prepare_pt2e, prepare_qat_pt2e, ) +from torchao.quantization.pt2e.quantizer import Quantizer class Stage(ABC): diff --git a/docs/source/backends-xnnpack.md b/docs/source/backends-xnnpack.md index e47e1115cc3..46ab379f186 100644 --- a/docs/source/backends-xnnpack.md +++ b/docs/source/backends-xnnpack.md @@ -91,11 +91,10 @@ The output of `convert_pt2e` is a PyTorch model which can be exported and lowere import torch import torchvision.models as models from torchvision.models.mobilenetv2 import MobileNet_V2_Weights -from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import XNNPACKQuantizer +from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import XNNPACKQuantizer, get_symmetric_quantization_config from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner from executorch.exir import to_edge_transform_and_lower from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e -from torch.ao.quantization.quantizer.xnnpack_quantizer import get_symmetric_quantization_config model = models.mobilenetv2.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).eval() sample_inputs = (torch.randn(1, 3, 224, 224), ) diff --git a/exir/tests/test_passes.py b/exir/tests/test_passes.py index ca2b5ebdc35..dd4037b64c0 100644 --- a/exir/tests/test_passes.py +++ b/exir/tests/test_passes.py @@ -71,7 +71,6 @@ from functorch.experimental import control_flow from torch import nn -from torch.ao.quantization.quantizer import QuantizationSpec from torch.export import export from torch.export.graph_signature import InputKind, InputSpec, TensorArgument from torch.fx import GraphModule, subgraph_rewriter @@ -81,6 +80,7 @@ from torch.utils import _pytree as 
pytree from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e +from torchao.quantization.pt2e.quantizer import QuantizationSpec # pyre-ignore @@ -1196,10 +1196,7 @@ def forward(self, query, key, value): ).module() # 8w16a quantization - from torch.ao.quantization.observer import ( - MinMaxObserver, - PerChannelMinMaxObserver, - ) + from torchao.quantization.pt2e import MinMaxObserver, PerChannelMinMaxObserver activation_qspec = QuantizationSpec( dtype=torch.int16, diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py index d32b44246f6..8b81587c434 100644 --- a/extension/llm/export/builder.py +++ b/extension/llm/export/builder.py @@ -13,7 +13,7 @@ import contextlib import logging from enum import Enum -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple from unittest.mock import patch import torch @@ -35,13 +35,6 @@ from executorch.extension.llm.export.export_passes import RemoveRedundantTransposes from pytorch_tokenizers import get_tokenizer - -# TODO: remove these once pt2e migration from torch.ao to torchao is complete -from torch.ao.quantization.quantizer import Quantizer as TorchQuantizer -from torch.ao.quantization.quantizer.composable_quantizer import ( - ComposableQuantizer as TorchComposableQuantizer, -) - from torch.export import export_for_training, ExportedProgram from torch.nn.attention import SDPBackend from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e @@ -356,9 +349,7 @@ def calibrate_template( print(f"{task}: {res}") logging.info("Calibration finish...") - def pt2e_quantize( - self, quantizers: Optional[List[Union[Quantizer, TorchQuantizer]]] - ) -> "LLMEdgeManager": + def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManager": """ Quantize the model via pt2e flow and retrieve LLMEdgeManager including the quantized model. Args: @@ -376,14 +367,7 @@ def pt2e_quantize( if self.verbose: logging.info(f"Applied quantizers: {quantizers}") - if all(isinstance(q, Quantizer) for q in quantizers): - composed_quantizer = ComposableQuantizer(quantizers) - elif all(isinstance(q, TorchQuantizer) for q in quantizers): - composed_quantizer = TorchComposableQuantizer(quantizers) - else: - raise ValueError( - "Quantizers must be either Quantizer or TorchQuantizer" - ) + composed_quantizer = ComposableQuantizer(quantizers) assert ( self.pre_autograd_graph_module is not None diff --git a/extension/llm/export/quantizer_lib.py b/extension/llm/export/quantizer_lib.py index b199ab57ccb..99499e34bb2 100644 --- a/extension/llm/export/quantizer_lib.py +++ b/extension/llm/export/quantizer_lib.py @@ -16,8 +16,8 @@ XNNPACKQuantizer, ) -from torch.ao.quantization.quantizer import Quantizer -from torch.ao.quantization.quantizer.embedding_quantizer import EmbeddingQuantizer +from torchao.quantization.pt2e.quantizer import Quantizer +from torchao.quantization.pt2e.quantizer.embedding_quantizer import EmbeddingQuantizer FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" logging.basicConfig(level=logging.INFO, format=FORMAT)
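
After this migration, observer and fake-quant classes resolve from `torchao.quantization.pt2e`, while quantizer primitives (`Quantizer`, `QuantizationSpec`, `QuantizationConfig`, the annotation helpers) resolve from `torchao.quantization.pt2e.quantizer`, matching what `vulkan_quantizer.py` and `xnnpack_quantizer.py` now import. A minimal sketch of building a per-channel weight spec from the new locations; the numeric ranges are illustrative rather than copied from any config in this diff, and the `QuantizationConfig` fields are assumed to match the dataclass this patch removes from `xnnpack_quantizer_utils.py`:

```python
import torch
from torchao.quantization.pt2e import PerChannelMinMaxObserver
from torchao.quantization.pt2e.quantizer import QuantizationConfig, QuantizationSpec

# Per-channel symmetric int8 weight spec, observed with the torchao observer class.
weight_qspec = QuantizationSpec(
    dtype=torch.int8,
    quant_min=-127,
    quant_max=127,
    qscheme=torch.per_channel_symmetric,
    ch_axis=0,
    is_dynamic=False,
    observer_or_fake_quant_ctr=PerChannelMinMaxObserver,
)

# Weight-only config: activations and bias left unquantized.
config = QuantizationConfig(
    input_activation=None,
    output_activation=None,
    weight=weight_qspec,
    bias=None,
)
```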
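`pt2e_quantize` in `extension/llm/export/builder.py` drops the `Quantizer`-vs-`TorchQuantizer` branching because every supported quantizer now derives from the single torchao `Quantizer` base class. A hedged sketch of composing quantizers under the new imports; the particular quantizer instances are illustrative:

```python
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
    XNNPACKQuantizer,
)
from torchao.quantization.pt2e.quantizer import ComposableQuantizer
from torchao.quantization.pt2e.quantizer.embedding_quantizer import EmbeddingQuantizer

embedding_quantizer = EmbeddingQuantizer()
linear_quantizer = XNNPACKQuantizer()
linear_quantizer.set_global(get_symmetric_quantization_config(is_per_channel=True))

# Both entries share the torchao Quantizer base class, so they compose directly
# without any isinstance checks.
composed = ComposableQuantizer([embedding_quantizer, linear_quantizer])
```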
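The updated `docs/source/backends-xnnpack.md` snippet now pulls `get_symmetric_quantization_config` from the ExecuTorch XNNPACK quantizer module rather than `torch.ao`. A sketch of the full prepare/convert/lower flow under the new imports, assuming the standard ExecuTorch lowering API and a torchvision model for calibration (mirroring the `_quantize` helper added to the tests):

```python
import torch
import torchvision.models as models
from torchvision.models.mobilenetv2 import MobileNet_V2_Weights

from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
    XNNPACKQuantizer,
)
from executorch.exir import to_edge_transform_and_lower
from torch.export import export, export_for_training
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

model = models.mobilenetv2.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).eval()
sample_inputs = (torch.randn(1, 3, 224, 224),)

# Annotate on the training-friendly export, calibrate, then convert.
training_gm = export_for_training(model, sample_inputs).module()
quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config(is_per_channel=True))
prepared = prepare_pt2e(training_gm, quantizer)
prepared(*sample_inputs)  # single calibration pass with the sample inputs
quantized = convert_pt2e(prepared)

# Lower the quantized graph to the XNNPACK backend.
exported = export(quantized, sample_inputs)
executorch_program = to_edge_transform_and_lower(
    exported,
    partitioner=[XnnpackPartitioner()],
).to_executorch()
```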