[quant][docs] Add documentation for prepare_pt2e, `prepare_qat_pt2e` and `convert_pt2e` (#110097)

Summary:
As titled.

Test Plan:
.

Reviewers:

Subscribers:

Tasks:

Tags:
Pull Request resolved: #110097
Approved by: https://github.com/kimishpatel
jerryzh168 authored and pytorchmergebot committed Sep 28, 2023
1 parent 3603f64 commit e3eb1d9
Showing 2 changed files with 130 additions and 6 deletions.
2 changes: 1 addition & 1 deletion torch/ao/quantization/quantize_fx.py
@@ -247,7 +247,7 @@ def prepare_fx(
_equalization_config: Optional[Union[QConfigMapping, Dict[str, Any]]] = None,
backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
) -> GraphModule:
r""" Prepare a model for post training static quantization
r""" Prepare a model for post training quantization
Args:
* `model` (torch.nn.Module): torch.nn.Module model
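For orientation, a minimal sketch of the FX graph mode flow that `prepare_fx` belongs to (illustrative only, not part of this diff; the toy model, backend string, and example inputs are placeholders):

import torch
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

float_model = torch.nn.Sequential(torch.nn.Linear(5, 10)).eval()
example_inputs = (torch.randn(1, 5),)

# insert observers according to the qconfig mapping
qconfig_mapping = get_default_qconfig_mapping("x86")
prepared = prepare_fx(float_model, qconfig_mapping, example_inputs)

# run calibration data through the observed model
prepared(*example_inputs)

# convert the calibrated model to a quantized model
quantized = convert_fx(prepared)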
134 changes: 129 additions & 5 deletions torch/ao/quantization/quantize_pt2e.py
@@ -39,6 +39,64 @@ def prepare_pt2e(
model: GraphModule,
quantizer: Quantizer,
) -> GraphModule:
"""Prepare a model for post training quantization
Args:
* `model` (torch.fx.GraphModule): a model captured by `torch.export` API
in the short term we are using `torch._export.capture_pre_autograd_graph`,
in the long term we'll migrate to some `torch.export` API
* `quantizer`: A backend specific quantizer that conveys how the user wants the
model to be quantized. A tutorial on how to write a quantizer can be found here:
https://pytorch.org/tutorials/prototype/pt2e_quantizer.html
Return:
A GraphModule with observer (based on quantizer annotation), ready for calibration
Example::
import torch
from torch.ao.quantization.quantize_pt2e import prepare_pt2e
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)
class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(5, 10)

    def forward(self, x):
        return self.linear(x)
# initialize a floating point model
float_model = M().eval()
# define calibration function
def calibrate(model, data_loader):
    model.eval()
    with torch.no_grad():
        for image, target in data_loader:
            model(image)
# Step 1. program capture
# NOTE: this API will be updated to the torch.export API in the future, but the captured
# result should mostly stay the same
example_inputs = (torch.randn(1, 5),)
m = capture_pre_autograd_graph(float_model, example_inputs)
# we get a model with aten ops
# Step 2. quantization
# a backend developer will write their own Quantizer and expose methods to allow
# users to express how they want the model to be quantized
quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
m = prepare_pt2e(m, quantizer)
# run calibration
# calibrate(m, sample_inference_data)
"""
torch._C._log_api_usage_once("quantization_api.quantize_pt2e.prepare_pt2e")
original_graph_meta = model.meta
node_name_to_scope = _get_node_name_to_scope(model)
# TODO: check qconfig_mapping to make sure conv and bn are both configured
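Putting the docstring example above together, a runnable sketch of the full post-training quantization flow (the toy model and random calibration data are placeholders; `convert_pt2e` is documented further below):

import torch
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

float_model = torch.nn.Sequential(torch.nn.Linear(5, 10)).eval()
example_inputs = (torch.randn(1, 5),)

# Step 1. program capture: produces a GraphModule with aten ops
m = capture_pre_autograd_graph(float_model, example_inputs)

# Step 2. insert observers based on the quantizer's annotations
quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
m = prepare_pt2e(m, quantizer)

# Step 3. calibration: run representative data through the observed model
with torch.no_grad():
    for _ in range(10):
        m(torch.randn(1, 5))

# Step 4. convert to a quantized model (q/dq representation by default)
quantized_model = convert_pt2e(m)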
@@ -56,6 +114,60 @@ def prepare_qat_pt2e(
model: GraphModule,
quantizer: Quantizer,
) -> GraphModule:
"""Prepare a model for quantization aware training
Args:
* `model` (torch.fx.GraphModule): see :func:`~torch.ao.quantization.quantize_pt2e.prepare_pt2e`
* `quantizer`: see :func:`~torch.ao.quantization.quantize_pt2e.prepare_pt2e`
Return:
A GraphModule with fake quant modules (based on quantizer annotation), ready for
quantization aware training
Example::
import torch
from torch.ao.quantization.quantize_pt2e import prepare_qat_pt2e
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)
class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(5, 10)

    def forward(self, x):
        return self.linear(x)
# initialize a floating point model
float_model = M().eval()
# define the training loop for quantization aware training
def train_loop(model, train_data):
    model.train()
    for image, target in train_data:
        ...
# Step 1. program capture
# NOTE: this API will be updated to the torch.export API in the future, but the captured
# result should mostly stay the same
example_inputs = (torch.randn(1, 5),)
m = capture_pre_autograd_graph(float_model, example_inputs)
# we get a model with aten ops
# Step 2. quantization
# a backend developer will write their own Quantizer and expose methods to allow
# users to express how they want the model to be quantized
quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
m = prepare_qat_pt2e(m, quantizer)
# run quantization aware training
train_loop(m, train_data)
"""
torch._C._log_api_usage_once("quantization_api.quantize_pt2e.prepare_qat_pt2e")
original_graph_meta = model.meta
node_name_to_scope = _get_node_name_to_scope(model)
quantizer.annotate(model)
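For completeness, a self-contained sketch of the quantization aware training flow the docstring above describes (the toy model, optimizer, and random training data are placeholders, not part of this change):

import torch
import torch.nn.functional as F
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import prepare_qat_pt2e, convert_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

float_model = torch.nn.Sequential(torch.nn.Linear(5, 10))
example_inputs = (torch.randn(4, 5),)
m = capture_pre_autograd_graph(float_model, example_inputs)

# insert fake quant modules based on a QAT quantization config
quantizer = XNNPACKQuantizer().set_global(
    get_symmetric_quantization_config(is_qat=True)
)
m = prepare_qat_pt2e(m, quantizer)

# toy stand-in for a real data loader and training loop
train_data = [(torch.randn(4, 5), torch.randint(0, 10, (4,))) for _ in range(10)]
optimizer = torch.optim.SGD(m.parameters(), lr=0.01)
for image, target in train_data:
    optimizer.zero_grad()
    loss = F.cross_entropy(m(image), target)
    loss.backward()
    optimizer.step()

# convert the trained, fake-quantized model to a quantized model
quantized_model = convert_pt2e(m)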
@@ -92,17 +204,29 @@ def convert_pt2e(
"""Convert a calibrated/trained model to a quantized model
Args:
- model: calibrated/trained model
- use_reference_representation: boolean flag to indicate whether to produce referece representation or not
- fold_quantize: boolean flag to indicate whether fold the quantize op or not
+ * `model` (torch.fx.GraphModule): calibrated/trained model
+ * `use_reference_representation` (bool): boolean flag to indicate whether to produce reference representation or not
+ * `fold_quantize` (bool): boolean flag to indicate whether to fold the quantize op or not
Note: please set `fold_quantize` to True whenever you can. We'll deprecate this flag and
make True the default option in the future; to make sure the change doesn't break backward
compatibility for you, it's better to set the flag to True now.
Returns:
- quantized model, either in q/dq representation or reference representation
- """
quantized model, either in q/dq representation or reference representation
Example::
# prepared_model: the model produced by `prepare_pt2e`/`prepare_qat_pt2e` and calibration/training
# `convert_pt2e` produces a quantized model that represents the quantized computation with
# quantize/dequantize ops and fp32 ops by default.
# Please refer to
# https://pytorch.org/tutorials/prototype/pt2e_quant_ptq_static.html#convert-the-calibrated-model-to-a-quantized-model
# for detailed explanation of output quantized model
quantized_model = convert_pt2e(prepared_model)
""" # flake8: noqa
torch._C._log_api_usage_once("quantization_api.quantize_pt2e.convert_pt2e")
original_graph_meta = model.meta
model = _convert_to_reference_decomposed_fx(model)
model = _fold_conv_bn_qat(model)
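A slightly expanded version of the Example above, spelling out the flags (`prepared_model` is a placeholder for the output of `prepare_pt2e`/`prepare_qat_pt2e` after calibration/training, as in the docstring):

from torch.ao.quantization.quantize_pt2e import convert_pt2e

# recommended per the Note above: fold the weight quantize ops into quantized constants
quantized_model = convert_pt2e(prepared_model, fold_quantize=True)

# other options (each is a separate convert call on a freshly prepared model):
#   convert_pt2e(prepared_model)                                     # default q/dq representation
#   convert_pt2e(prepared_model, use_reference_representation=True)  # reference representation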