[reland][quant][docs] Add fx graph mode quantization to quantization docs (#49211) (#49515)

Summary: Pull Request resolved: #49515

Test Plan:
Imported from OSS

Reviewed By: vkuzo

Differential Revision: D25601061

fbshipit-source-id: 74e917d57895e9b4131a01fdcea8df3e94322bec
jerryzh168 authored and facebook-github-bot committed Dec 17, 2020
1 parent 815d383 commit b8d98f0
Showing 5 changed files with 109 additions and 16 deletions.
4 changes: 2 additions & 2 deletions caffe2/contrib/aten/gen_op.py
@@ -20,7 +20,7 @@
import argparse
import os
from copy import deepcopy
-from typing import Dict, List
+from typing import Dict, List, Set

parser = argparse.ArgumentParser()
parser.add_argument("--template_dir", default=".", help="where template.h is")
@@ -241,7 +241,7 @@ def emit_assignments(o, env):
'implementations': [],
'cases': [],
} # type: Dict[str, List]
-seen = set()
+seen: Set[str] = set()
key = 0
for o in filtered:
# [DESCRIPTORS]
5 changes: 1 addition & 4 deletions docs/source/quantization-support.rst
@@ -105,7 +105,7 @@ Fused modules are provided for common patterns in CNNs. Combining several
operations together (like convolution and relu) allows for better quantization
accuracy


* `torch.nn.intrinsic` — float versions of the modules, can be swapped one to
  one with the quantized versions:

@@ -172,7 +172,6 @@ Layers for the quantization-aware training
* :func:`~torch.quantization.fuse_modules`

* Functions for graph mode quantization:

* :func:`~torch.quantization.quantize_jit` - Function for graph mode post training static quantization
* :func:`~torch.quantization.quantize_dynamic_jit` - Function for graph mode post training dynamic quantization

@@ -322,5 +321,3 @@ Quantized dtypes and quantization schemes
* :attr:`torch.quint8` — 8-bit unsigned integer
* :attr:`torch.qint8` — 8-bit signed integer
* :attr:`torch.qint32` — 32-bit signed integer
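
A quick illustration of these dtypes (a hedged sketch; `torch.quantize_per_tensor`
is an existing API, and the scale and zero point values here are arbitrary)::

    import torch

    x = torch.randn(4)
    # quantize a float tensor into an 8-bit signed quantized tensor
    xq = torch.quantize_per_tensor(x, scale=0.1, zero_point=0, dtype=torch.qint8)
    print(xq.dtype)        # torch.qint8
    print(xq.int_repr())   # the underlying int8 values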


103 changes: 99 additions & 4 deletions docs/source/quantization.rst
@@ -80,7 +80,16 @@ The corresponding implementation is chosen automatically based on the PyTorch bu
Quantization API Summary
---------------------------------------

-There are three types of quantization supported in PyTorch:
PyTorch provides two different modes of quantization: Eager Mode Quantization and FX Graph Mode Quantization.

Eager Mode Quantization is a beta feature. The user needs to do fusion and to specify where quantization and dequantization happen manually; it also supports only modules, not functionals.
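
As a minimal sketch of what this means (the module and layers here are
illustrative, not from the original docs), the user marks the quantization
boundaries manually with `QuantStub`/`DeQuantStub`::

    import torch

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # QuantStub converts incoming float tensors to quantized tensors
            self.quant = torch.quantization.QuantStub()
            self.conv = torch.nn.Conv2d(1, 1, 1)
            self.relu = torch.nn.ReLU()
            # DeQuantStub converts outgoing quantized tensors back to float
            self.dequant = torch.quantization.DeQuantStub()

        def forward(self, x):
            x = self.quant(x)       # enter the quantized region
            x = self.relu(self.conv(x))
            return self.dequant(x)  # leave the quantized region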

FX Graph Mode Quantization is a new automated quantization framework in PyTorch, currently a prototype feature. It improves upon Eager Mode Quantization by adding support for functionals and by automating the quantization process, although users may need to refactor the model a bit to make it compatible with FX Graph Mode Quantization (symbolically traceable with `torch.fx`).
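
A quick way to check traceability is to run the model through
`torch.fx.symbolic_trace` (a minimal sketch; `MyModel` stands in for a
user-defined module)::

    import torch.fx

    model = MyModel()
    # raises an error if the model is not symbolically traceable,
    # e.g. because of data-dependent control flow
    graph_module = torch.fx.symbolic_trace(model)
    print(graph_module.graph)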

Eager Mode Quantization
^^^^^^^^^^^^^^^^^^^^^^^

There are three types of quantization supported in Eager Mode Quantization:

1. dynamic quantization (weights quantized, with activations read/stored in
   floating point and quantized for compute)
@@ -95,7 +104,7 @@ for a more comprehensive overview of the tradeoffs between these quantization
types.

Dynamic Quantization
-^^^^^^^^^^^^^^^^^^^^
+~~~~~~~~~~~~~~~~~~~~

This is the simplest form of quantization to apply: the weights are
quantized ahead of time but the activations are dynamically quantized
@@ -148,7 +157,7 @@ To learn more about dynamic quantization please see our `dynamic quantization tu
<https://pytorch.org/tutorials/recipes/recipes/dynamic_quantization.html>`_.
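
As a minimal sketch of the eager-mode dynamic quantization API (the one-layer
model here is illustrative)::

    import torch

    model_fp32 = torch.nn.Sequential(torch.nn.Linear(10, 10))
    # swap Linear modules for dynamically quantized versions: weights are
    # converted to int8 ahead of time, activations are quantized on the fly
    model_int8 = torch.quantization.quantize_dynamic(
        model_fp32, {torch.nn.Linear}, dtype=torch.qint8)
    out = model_int8(torch.randn(2, 10))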

Static Quantization
-^^^^^^^^^^^^^^^^^^^^
+~~~~~~~~~~~~~~~~~~~

Static quantization quantizes the weights and activations of the model. It
fuses activations into preceding layers where possible. It requires
@@ -238,7 +247,7 @@ To learn more about static quantization, please see the `static quantization tut
<https://pytorch.org/tutorials/advanced/static_quantization_tutorial.html>`_.
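
A condensed, hedged sketch of the eager-mode flow, reusing the illustrative
`M` module defined above (fuse, pick a backend config, insert observers,
calibrate, convert)::

    import torch

    model_fp32 = M()
    model_fp32.eval()
    # fuse conv + relu into a single module where possible
    model_fused = torch.quantization.fuse_modules(model_fp32, [['conv', 'relu']])
    # pick a backend configuration, e.g. 'fbgemm' for x86 servers
    model_fused.qconfig = torch.quantization.get_default_qconfig('fbgemm')
    # insert observers that record activation statistics
    model_prepared = torch.quantization.prepare(model_fused)
    # calibrate (random data stands in for a representative dataset)
    model_prepared(torch.randn(4, 1, 8, 8))
    # convert to a quantized model
    model_int8 = torch.quantization.convert(model_prepared)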

Quantization Aware Training
-^^^^^^^^^^^^^^^^^^^^^^^^^^^
+~~~~~~~~~~~~~~~~~~~~~~~~~~~

Quantization Aware Training models the effects of quantization during training
allowing for higher accuracy compared to other quantization methods. During
@@ -332,6 +341,92 @@ To learn more about quantization aware training, please see the `QAT
tutorial
<https://pytorch.org/tutorials/advanced/static_quantization_tutorial.html>`_.
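
A condensed, hedged sketch of the eager-mode QAT flow, again reusing the
illustrative `M` module (the model must be in training mode when
`prepare_qat` is called)::

    import torch

    model_fp32 = M()
    model_fp32.train()
    model_fp32.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
    # insert fake-quantization modules that model int8 behavior during training
    model_prepared = torch.quantization.prepare_qat(model_fp32)
    # ... run the training loop on model_prepared (not shown) ...
    model_prepared.eval()
    model_int8 = torch.quantization.convert(model_prepared)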

(Prototype) FX Graph Mode Quantization
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Quantization types supported by FX Graph Mode Quantization can be classified in two ways:

1. By when quantization is applied:

   - Post Training Quantization (quantization is applied after training;
     quantization parameters are calculated based on sample calibration data)
   - Quantization Aware Training (quantization is simulated during training so
     that the quantization parameters can be learned together with the model
     using training data)

2. By which values are quantized, and how:

   - Weight Only Quantization (only the weights are statically quantized)
   - Dynamic Quantization (weights are statically quantized, activations are
     dynamically quantized)
   - Static Quantization (both weights and activations are statically quantized)
These two classifications are independent, so in theory there are six different types of quantization.

The supported quantization types in FX Graph Mode Quantization are:

- Post Training Quantization

  - Weight Only Quantization
  - Dynamic Quantization
  - Static Quantization

- Quantization Aware Training

  - Static Quantization

Post training quantization supports multiple quantization types (weight only, dynamic, and static), and the configuration is done through `qconfig_dict` (an argument of the `prepare_fx` function).

API Example::

import torch
import torch.quantization.quantize_fx as quantize_fx
import copy

model_fp = UserModel(...)

#
# post training dynamic/weight_only quantization
#

# deepcopy the model: the quantization APIs change the input model in place,
# and we want to keep model_fp unchanged
model_to_quantize = copy.deepcopy(model_fp)
model_to_quantize.eval()
qconfig_dict = {"": torch.quantization.default_dynamic_qconfig}
# prepare
model_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_dict)
# no calibration needed when we only have dynamic/weight_only quantization
# quantize
model_quantized = quantize_fx.convert_fx(model_prepared)

#
# post training static quantization
#

model_to_quantize = copy.deepcopy(model_fp)
qconfig_dict = {"": torch.quantization.get_default_qconfig('qnnpack')}
model_to_quantize.eval()
# prepare
model_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_dict)
# calibrate (not shown)
# quantize
model_quantized = quantize_fx.convert_fx(model_prepared)

#
# quantization aware training for static quantization
#

model_to_quantize = copy.deepcopy(model_fp)
qconfig_dict = {"": torch.quantization.get_default_qat_qconfig('qnnpack')}
model_to_quantize.train()
# prepare
model_prepared = quantize_fx.prepare_qat_fx(model_to_quantize, qconfig_dict)
# training loop (not shown)
# quantize
model_quantized = quantize_fx.convert_fx(model_prepared)

#
# fusion
#
model_to_quantize = copy.deepcopy(model_fp)
model_fused = quantize_fx.fuse_fx(model_to_quantize)
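
#
# inference and serialization (an illustrative sketch, not from the original
# docs: the input shape is made up; torch.jit.script/torch.jit.save are
# standard TorchScript APIs)
#
out = model_quantized(torch.randn(1, 3, 224, 224))
scripted = torch.jit.script(model_quantized)
torch.jit.save(scripted, "quantized_model.pt")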

Please see the following tutorials for more information about FX Graph Mode Quantization:

- FX Graph Mode Post Training Static Quantization (TODO: link)
- FX Graph Mode Post Training Dynamic Quantization (TODO: link)

Quantized Tensors
---------------------------------------

2 changes: 1 addition & 1 deletion tools/codegen/gen.py
@@ -705,7 +705,7 @@ def compute_native_function_declaration(g: Union[StructuredNativeFunctions, Nati
# only out has dispatch
meta_name = meta.name(g)
rs = []
-seen = set()
+seen: Set[Any] = set()
out_args = native.arguments(g.out.func)
for k, n in g.out.dispatch.items():
if n in seen:
11 changes: 6 additions & 5 deletions torch/utils/_pytree.py
@@ -112,8 +112,8 @@ def tree_flatten(pytree: PyTree) -> Tuple[List[Any], TreeSpec]:
child_pytrees, context = flatten_fn(pytree)

# Recursively flatten the children
-result = []
-children_specs = []
+result : List[Any] = []
+children_specs : List['TreeSpec'] = []
for child in child_pytrees:
    flat, child_spec = tree_flatten(child)
    result += flat
@@ -178,11 +178,12 @@ def _broadcast_to_and_flatten(pytree: PyTree, spec: TreeSpec) -> Optional[List[A
return None

# Recursively flatten the children
-result = []
+result : List[Any] = []
for child, child_spec in zip(child_pytrees, spec.children_specs):
    flat = _broadcast_to_and_flatten(child, child_spec)
-    if flat is None:
-        return None
-    result += flat
+    if flat is not None:
+        result += flat
+    else:
+        return None

return result
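
For context, a hedged usage sketch of the helpers this hunk touches (based on
a reading of the code above; `tree_flatten` and `_broadcast_to_and_flatten`
are internal utilities in `torch.utils._pytree`):

    from torch.utils._pytree import tree_flatten, _broadcast_to_and_flatten

    flat, spec = tree_flatten({'a': [1, 2], 'b': 3})
    # flat == [1, 2, 3]; spec records the dict/list structure
    # a single leaf is broadcast to every leaf position of the spec
    print(_broadcast_to_and_flatten(0, spec))                 # [0, 0, 0]
    # per-subtree values are broadcast recursively
    print(_broadcast_to_and_flatten({'a': 0, 'b': 1}, spec))  # [0, 0, 1]
    # a structure mismatch yields None
    print(_broadcast_to_and_flatten([0], spec))               # None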
