From e3eb1d92d8e26db37a0c06e40b71d744b7a5fc63 Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Wed, 27 Sep 2023 12:56:20 -0700
Subject: [PATCH] [quant][docs] Add documentation for `prepare_pt2e`, `prepare_qat_pt2e` and `convert_pt2e` (#110097)

Summary:
att

Test Plan:
.

Reviewers:

Subscribers:

Tasks:

Tags:

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110097
Approved by: https://github.com/kimishpatel
---
 torch/ao/quantization/quantize_fx.py   |   2 +-
 torch/ao/quantization/quantize_pt2e.py | 134 ++++++++++++++++++++++++-
 2 files changed, 130 insertions(+), 6 deletions(-)

diff --git a/torch/ao/quantization/quantize_fx.py b/torch/ao/quantization/quantize_fx.py
index 9dca46aa4d72a..3fe676afec1be 100644
--- a/torch/ao/quantization/quantize_fx.py
+++ b/torch/ao/quantization/quantize_fx.py
@@ -247,7 +247,7 @@ def prepare_fx(
     _equalization_config: Optional[Union[QConfigMapping, Dict[str, Any]]] = None,
     backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
 ) -> GraphModule:
-    r""" Prepare a model for post training static quantization
+    r""" Prepare a model for post training quantization

     Args:
       * `model` (torch.nn.Module): torch.nn.Module model
diff --git a/torch/ao/quantization/quantize_pt2e.py b/torch/ao/quantization/quantize_pt2e.py
index f5dbd5910aad6..765cb9446bfd1 100644
--- a/torch/ao/quantization/quantize_pt2e.py
+++ b/torch/ao/quantization/quantize_pt2e.py
@@ -39,6 +39,64 @@ def prepare_pt2e(
    model: GraphModule,
    quantizer: Quantizer,
) -> GraphModule:
    """Prepare a model for post training quantization

    Args:
      * `model` (torch.fx.GraphModule): a model captured by the `torch.export` API;
        in the short term we are using `torch._export.capture_pre_autograd_graph`,
        in the long term we'll migrate to the `torch.export` API
      * `quantizer`: a backend-specific quantizer that conveys how the user wants the
        model to be quantized. A tutorial on how to write a quantizer can be found here:
        https://pytorch.org/tutorials/prototype/pt2e_quantizer.html

    Return:
      A GraphModule with observers (based on the quantizer annotation), ready for calibration

    Example::

        import torch
        from torch.ao.quantization.quantize_pt2e import prepare_pt2e
        from torch._export import capture_pre_autograd_graph
        from torch.ao.quantization.quantizer.xnnpack_quantizer import (
            XNNPACKQuantizer,
            get_symmetric_quantization_config,
        )

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(5, 10)

            def forward(self, x):
                return self.linear(x)

        # initialize a floating point model and example inputs for capture
        float_model = M().eval()
        example_inputs = (torch.randn(1, 5),)

        # define calibration function
        def calibrate(model, data_loader):
            model.eval()
            with torch.no_grad():
                for image, target in data_loader:
                    model(image)

        # Step 1. program capture
        # NOTE: this API will be updated to the torch.export API in the future, but the captured
        # result should mostly stay the same
        m = capture_pre_autograd_graph(float_model, example_inputs)
        # we get a model with aten ops

        # Step 2. quantization
        # backend developer will write their own Quantizer and expose methods to allow
        # users to express how they want the model to be quantized
        quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
        m = prepare_pt2e(m, quantizer)

        # run calibration
        # calibrate(m, sample_inference_data)
    """
    torch._C._log_api_usage_once("quantization_api.quantize_pt2e.prepare_pt2e")
    original_graph_meta = model.meta
    node_name_to_scope = _get_node_name_to_scope(model)
    # TODO: check qconfig_mapping to make sure conv and bn are both configured
@@ -56,6 +114,60 @@ def prepare_qat_pt2e(
    model: GraphModule,
    quantizer: Quantizer,
) -> GraphModule:
    """Prepare a model for quantization aware training

    Args:
      * `model` (torch.fx.GraphModule): see :func:`~torch.ao.quantization.quantize_pt2e.prepare_pt2e`
      * `quantizer`: see :func:`~torch.ao.quantization.quantize_pt2e.prepare_pt2e`

    Return:
      A GraphModule with fake quant modules (based on the quantizer annotation), ready for
      quantization aware training

    Example::

        import torch
        from torch.ao.quantization.quantize_pt2e import prepare_qat_pt2e
        from torch._export import capture_pre_autograd_graph
        from torch.ao.quantization.quantizer.xnnpack_quantizer import (
            XNNPACKQuantizer,
            get_symmetric_quantization_config,
        )

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(5, 10)

            def forward(self, x):
                return self.linear(x)

        # initialize a floating point model and example inputs for capture
        float_model = M().eval()
        example_inputs = (torch.randn(1, 5),)

        # define the training loop for quantization aware training
        def train_loop(model, train_data):
            model.train()
            for image, target in train_data:
                ...

        # Step 1. program capture
        # NOTE: this API will be updated to the torch.export API in the future, but the captured
        # result should mostly stay the same
        m = capture_pre_autograd_graph(float_model, example_inputs)
        # we get a model with aten ops

        # Step 2. quantization
        # backend developer will write their own Quantizer and expose methods to allow
        # users to express how they want the model to be quantized
        quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
        m = prepare_qat_pt2e(m, quantizer)

        # run quantization aware training
        # train_loop(m, train_data)
    """
    torch._C._log_api_usage_once("quantization_api.quantize_pt2e.prepare_qat_pt2e")
    original_graph_meta = model.meta
    node_name_to_scope = _get_node_name_to_scope(model)
    quantizer.annotate(model)
@@ -92,17 +204,29 @@ def convert_pt2e(
    """Convert a calibrated/trained model to a quantized model

    Args:
-      model: calibrated/trained model
-      use_reference_representation: boolean flag to indicate whether to produce referece representation or not
-      fold_quantize: boolean flag to indicate whether fold the quantize op or not
+      * `model` (torch.fx.GraphModule): calibrated/trained model
+      * `use_reference_representation` (bool): boolean flag to indicate whether to produce the reference representation or not
+      * `fold_quantize` (bool): boolean flag to indicate whether to fold the quantize op or not

    Note: please set `fold_quantize` to True whenever you can; we'll deprecate this flag and
    make True the default option in the future. To make sure the change doesn't break BC for you,
    it's better to set the flag to True now.
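
    For example, to opt in to the folded output now (a minimal sketch; `prepared_model` is assumed
    to be the output of `prepare_pt2e`/`prepare_qat_pt2e` after calibration/training)::

        quantized_model = convert_pt2e(prepared_model, fold_quantize=True)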
    Returns:
-        quantized model, either in q/dq representation or reference representation
-    """
+        quantized model, either in q/dq representation or reference representation
+
+    Example::
+
+        # prepared_model: the model produced by `prepare_pt2e`/`prepare_qat_pt2e` and calibration/training
+        # `convert_pt2e` produces a quantized model that represents quantized computation with
+        # quantize/dequantize ops and fp32 ops by default.
+        # Please refer to
+        # https://pytorch.org/tutorials/prototype/pt2e_quant_ptq_static.html#convert-the-calibrated-model-to-a-quantized-model
+        # for a detailed explanation of the output quantized model
+        quantized_model = convert_pt2e(prepared_model)
+
+    """  # flake8: noqa
+    torch._C._log_api_usage_once("quantization_api.quantize_pt2e.convert_pt2e")
     original_graph_meta = model.meta
     model = _convert_to_reference_decomposed_fx(model)
     model = _fold_conv_bn_qat(model)
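
For reference, the three APIs documented above compose into a single post training
quantization flow. The sketch below is a minimal illustration assuming the
`XNNPACKQuantizer` setup from the docstring examples; the toy model, `example_inputs`,
and the one-batch calibration call are illustrative placeholders::

    import torch
    from torch._export import capture_pre_autograd_graph
    from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
    from torch.ao.quantization.quantizer.xnnpack_quantizer import (
        XNNPACKQuantizer,
        get_symmetric_quantization_config,
    )

    # a toy floating point model and example inputs for program capture
    float_model = torch.nn.Sequential(torch.nn.Linear(5, 10)).eval()
    example_inputs = (torch.randn(1, 5),)

    # Step 1. program capture: produces an aten-level GraphModule
    m = capture_pre_autograd_graph(float_model, example_inputs)

    # Step 2. insert observers according to the quantizer's annotations
    quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
    m = prepare_pt2e(m, quantizer)

    # Step 3. calibration: run representative data through the observed model
    m(*example_inputs)

    # Step 4. convert to a quantized model (q/dq representation by default)
    quantized_model = convert_pt2e(m, fold_quantize=True)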