[Quant][docs] Move BackendConfig API specification
Summary: This commit moves the API specification section of
the BackendConfig tutorial to the docstrings, which is a more
suitable place for this content. This change also reduces some
duplication. There is no new content added in this change.

Reviewers: jerryzh168, vkuzo

Subscribers: jerryzh168, vkuzo

ghstack-source-id: 50e1bc189d628b544776f387b9ede7de3c258e48
Pull Request resolved: #91999
andrewor14 committed Jan 11, 2023
1 parent d7dc1c2 commit ecb56a7
Showing 2 changed files with 39 additions and 16 deletions.
1 change: 1 addition & 0 deletions docs/source/quantization-support.rst
@@ -99,6 +99,7 @@ Quantization to work with this as well.
     BackendConfig
     BackendPatternConfig
     DTypeConfig
+    DTypeWithConstraints
     ObservationType
 
 torch.ao.quantization.fx.custom_config
54 changes: 38 additions & 16 deletions torch/ao/quantization/backend_config/backend_config.py
@@ -83,8 +83,10 @@ class DTypeWithConstraints:
 @dataclass
 class DTypeConfig:
     """
-    Config for the set of supported input/output activation, weight, and bias data types for the
-    patterns defined in :class:`~torch.ao.quantization.backend_config.BackendConfig`.
+    Config object that specifies the supported data types passed as arguments to
+    quantize ops in the reference model spec, for input and output activations,
+    weights, and biases. This object also optionally specifies constraints
+    associated with the data types.
 
     Example usage::
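
The example usage itself is collapsed in this view. For context, a minimal sketch of constructing a DTypeConfig (illustrative usage of the torch.ao.quantization.backend_config API; not part of this diff):

    import torch
    from torch.ao.quantization.backend_config import DTypeConfig

    # Supported dtypes passed to the quantize ops in the reference model spec:
    # quint8 activations, qint8 weights, and fp32 bias.
    weighted_int8_dtype_config = DTypeConfig(
        input_dtype=torch.quint8,
        output_dtype=torch.quint8,
        weight_dtype=torch.qint8,
        bias_dtype=torch.float)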
@@ -353,9 +355,8 @@ def to_dict(self) -> Dict[str, Any]:

 class BackendPatternConfig:
     """
-    Config for ops defined in :class:`~torch.ao.quantization.backend_config.BackendConfig`.
+    Config object that specifies quantization behavior for a given operator pattern.
 
     For a detailed example usage, see :class:`~torch.ao.quantization.backend_config.BackendConfig`.
     """
     def __init__(self, pattern: Optional[Pattern] = None):
         self.pattern: Optional[Pattern] = pattern
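
To illustrate the new docstring, a hedged sketch of the fluent usage BackendPatternConfig supports (assumes the API at this commit; the dtype values and module choices are examples only):

    import torch
    from torch.ao.quantization.backend_config import (
        BackendPatternConfig,
        DTypeConfig,
        ObservationType,
    )

    # Describe how the backend quantizes the torch.nn.Linear pattern.
    linear_config = (
        BackendPatternConfig(torch.nn.Linear)
        .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT)
        .add_dtype_config(DTypeConfig(
            input_dtype=torch.quint8,
            output_dtype=torch.quint8,
            weight_dtype=torch.qint8,
            bias_dtype=torch.float))
        .set_root_module(torch.nn.Linear)
        .set_qat_module(torch.ao.nn.qat.Linear)
        .set_reference_quantized_module(torch.ao.nn.quantized.reference.Linear)
    )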
@@ -401,39 +402,55 @@ def set_pattern(self, pattern: Pattern) -> BackendPatternConfig:
     def set_observation_type(self, observation_type: ObservationType) -> BackendPatternConfig:
         """
         Set how observers should be inserted in the graph for this pattern.
+
+        Observation type here refers to how observers (or quant-dequant ops) will be placed
+        in the graph. This is used to produce the desired reference patterns understood by
+        the backend. Weighted ops such as linear and conv require different observers
+        (or quantization parameters passed to quantize ops in the reference model) for the
+        input and the output.
+
+        There are two observation types:
 
-        `OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT` (default): the output observer instance will be
-        different from the input. This is the most common observation type.
+            `OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT` (default): the output observer instance
+            will be different from the input. This is the most common observation type.
 
-        `OUTPUT_SHARE_OBSERVER_WITH_INPUT`: the output observer instance will be the same as the input.
-        This is useful for operators like `cat`.
+            `OUTPUT_SHARE_OBSERVER_WITH_INPUT`: the output observer instance will be the
+            same as the input. This is useful for operators like `cat`.
 
-        Note: This will be renamed in the near future, since we will soon insert QuantDeQuantStubs with
-        observers (and fake quantizes) attached instead of observers themselves.
+        Note: This will be renamed in the near future, since we will soon insert QuantDeQuantStubs
+        with observers (and fake quantizes) attached instead of observers themselves.
         """
         self.observation_type = observation_type
         return self
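
A short sketch of how the two observation types are chosen in practice (illustrative, not part of this diff):

    import torch
    from torch.ao.quantization.backend_config import BackendPatternConfig, ObservationType

    # Weighted op: input and output get separate observer instances (the default).
    linear_config = BackendPatternConfig(torch.nn.Linear) \
        .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT)

    # cat passes quantization parameters through unchanged, so its output
    # shares the observer instance of its input.
    cat_config = BackendPatternConfig(torch.cat) \
        .set_observation_type(ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT)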

     def add_dtype_config(self, dtype_config: DTypeConfig) -> BackendPatternConfig:
         """
-        Add a set of supported input/output activation, weight, and bias data types for this pattern.
+        Add a set of supported data types passed as arguments to quantize ops in the
+        reference model spec.
         """
         self.dtype_configs.append(dtype_config)
         return self

     def set_dtype_configs(self, dtype_configs: List[DTypeConfig]) -> BackendPatternConfig:
         """
-        Set the supported input/output activation, weight, and bias data types for this pattern,
-        overriding all previously registered data types.
+        Set the supported data types passed as arguments to quantize ops in the
+        reference model spec, overriding all previously registered data types.
         """
         self.dtype_configs = dtype_configs
         return self
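
A sketch contrasting the two methods (illustrative; the fp16 config is an assumed example, not from this diff):

    import torch
    from torch.ao.quantization.backend_config import BackendPatternConfig, DTypeConfig

    int8_dtype_config = DTypeConfig(
        input_dtype=torch.quint8,
        output_dtype=torch.quint8,
        weight_dtype=torch.qint8,
        bias_dtype=torch.float)
    fp16_dtype_config = DTypeConfig(
        input_dtype=torch.float16,
        output_dtype=torch.float16,
        weight_dtype=torch.float16,
        bias_dtype=torch.float16)

    config = BackendPatternConfig(torch.nn.Linear)
    # Append one supported dtype combination at a time...
    config.add_dtype_config(int8_dtype_config)
    # ...or replace all previously registered combinations at once.
    config.set_dtype_configs([int8_dtype_config, fp16_dtype_config])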

     def set_root_module(self, root_module: Type[torch.nn.Module]) -> BackendPatternConfig:
         """
         Set the module that represents the root for this pattern.
-        For example, the root module for :class:`torch.nn.intrinsic.LinearReLU` should be :class:`torch.nn.Linear`.
+
+        When we construct the reference quantized model during the convert phase,
+        the root modules (e.g. torch.nn.Linear for torch.ao.nn.intrinsic.LinearReLU)
+        will be swapped to the corresponding reference quantized modules (e.g.
+        torch.ao.nn.quantized.reference.Linear). This allows custom backends to
+        specify custom reference quantized module implementations to match the
+        numerics of their lowered operators. Since this is a one-to-one mapping,
+        both the root module and the reference quantized module must be specified
+        in the same BackendPatternConfig in order for the conversion to take place.
         """
         self.root_module = root_module
         return self
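
A minimal sketch of the one-to-one mapping the docstring describes (illustrative only):

    import torch
    from torch.ao.quantization.backend_config import BackendPatternConfig

    # For the fused LinearReLU pattern, the root is Linear; during convert,
    # it is swapped for the reference quantized Linear specified alongside it.
    config = (
        BackendPatternConfig(torch.ao.nn.intrinsic.LinearReLU)
        .set_root_module(torch.nn.Linear)
        .set_reference_quantized_module(torch.ao.nn.quantized.reference.Linear)
    )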
@@ -447,7 +464,10 @@ def set_qat_module(self, qat_module: Type[torch.nn.Module]) -> BackendPatternConfig:

     def set_reference_quantized_module(self, reference_quantized_module: Type[torch.nn.Module]) -> BackendPatternConfig:
         """
-        Set the module that represents the reference quantized implementation for this pattern's root module.
+        Set the module that represents the reference quantized implementation for
+        this pattern's root module.
+
+        For more detail, see :func:`~torch.ao.quantization.backend_config.BackendPatternConfig.set_root_module`.
         """
         self.reference_quantized_module = reference_quantized_module
         return self
@@ -461,7 +481,7 @@ def set_fused_module(self, fused_module: Type[torch.nn.Module]) -> BackendPatternConfig:

     def set_fuser_method(self, fuser_method: Callable) -> BackendPatternConfig:
         """
-        Set the function that specifies how to fuse the pattern for this pattern.
+        Set the function that specifies how to fuse this BackendPatternConfig's pattern.
 
         The first argument of this function should be `is_qat`, and the rest of the arguments
         should be the items in the tuple pattern. The return value of this function should be
@@ -471,6 +491,8 @@ def fuse_linear_relu(is_qat, linear, relu):
             def fuse_linear_relu(is_qat, linear, relu):
                 return torch.ao.nn.intrinsic.LinearReLU(linear, relu)
 
+        For a more complicated example, see https://gist.github.com/jerryzh168/8bea7180a8ba3c279f2c9b050f2a69a6.
         """
         self.fuser_method = fuser_method
         return self
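
Putting the fuser method together with the fused module, a hedged end-to-end sketch (assumes the tuple pattern format (torch.nn.Linear, torch.nn.ReLU) in use at the time of this commit):

    import torch
    from torch.ao.quantization.backend_config import BackendPatternConfig

    def fuse_linear_relu(is_qat, linear, relu):
        # Fuse a (Linear, ReLU) pair into a single LinearReLU module.
        return torch.ao.nn.intrinsic.LinearReLU(linear, relu)

    config = (
        BackendPatternConfig((torch.nn.Linear, torch.nn.ReLU))
        .set_fuser_method(fuse_linear_relu)
        .set_fused_module(torch.ao.nn.intrinsic.LinearReLU)
    )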