diff --git a/backends/cadence/aot/export_example.py b/backends/cadence/aot/export_example.py index 0204f717fb7..092bbc4b192 100644 --- a/backends/cadence/aot/export_example.py +++ b/backends/cadence/aot/export_example.py @@ -9,6 +9,8 @@ import logging import tempfile +import torch + from executorch.backends.cadence.aot.ops_registrations import * # noqa from typing import Any, Tuple @@ -17,11 +19,17 @@ export_to_cadence_edge_executorch, fuse_pt2, ) + from executorch.backends.cadence.aot.quantizer.quantizer import CadenceQuantizer from executorch.backends.cadence.runtime import runtime from executorch.backends.cadence.runtime.executor import BundledProgramManager from executorch.exir import ExecutorchProgramManager from torch import nn +from torch.ao.quantization.observer import HistogramObserver, MinMaxObserver +from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import ( + QuantizationConfig, + QuantizationSpec, +) from .utils import save_bpte_program, save_pte_program @@ -29,6 +37,24 @@ FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" logging.basicConfig(level=logging.INFO, format=FORMAT) +act_qspec = QuantizationSpec( + dtype=torch.int8, + quant_min=-128, + quant_max=127, + qscheme=torch.per_tensor_affine, + is_dynamic=False, + observer_or_fake_quant_ctr=HistogramObserver.with_args(eps=2**-12), +) + +wgt_qspec = QuantizationSpec( + dtype=torch.int8, + quant_min=-128, + quant_max=127, + qscheme=torch.per_tensor_affine, + is_dynamic=False, + observer_or_fake_quant_ctr=MinMaxObserver, +) + def export_model( model: nn.Module, @@ -39,8 +65,15 @@ def export_model( working_dir = tempfile.mkdtemp(dir="/tmp") logging.debug(f"Created work directory {working_dir}") + qconfig = QuantizationConfig( + act_qspec, + act_qspec, + wgt_qspec, + None, + ) + # Instantiate the quantizer - quantizer = CadenceQuantizer() + quantizer = CadenceQuantizer(qconfig) # Convert the model converted_model = convert_pt2(model, example_inputs, quantizer) diff --git a/backends/cadence/aot/quantizer/quantizer.py b/backends/cadence/aot/quantizer/quantizer.py index 51bace91685..73ca40c9aa6 100644 --- a/backends/cadence/aot/quantizer/quantizer.py +++ b/backends/cadence/aot/quantizer/quantizer.py @@ -141,13 +141,20 @@ def get_supported_operators(cls) -> List[OperatorConfig]: class CadenceQuantizer(ComposableQuantizer): - def __init__(self) -> None: - static_qconfig = QuantizationConfig( - act_qspec, - act_qspec, - wgt_qspec, - None, + def __init__( + self, quantization_config: Optional[QuantizationConfig] = None + ) -> None: + static_qconfig = ( + QuantizationConfig( + act_qspec, + act_qspec, + wgt_qspec, + None, + ) + if not quantization_config + else quantization_config ) + super().__init__( [ CadenceAtenQuantizer(AddmmPattern(), static_qconfig),