Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[quant][refactor] Remove register api and rename get_*_mapping to get_default_*_mapping #46337

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
17 changes: 7 additions & 10 deletions torch/quantization/__init__.py
Expand Up @@ -30,16 +30,13 @@ def default_eval_fn(model, calib_data):
'prepare_fx', 'prepare_dynamic_fx', 'convert_fx',
'QuantType', # quantization type
# custom module APIs
'register_static_quant_module_mapping',
'get_static_quant_module_mappings', 'get_static_quant_module_class',
'register_dynamic_quant_module_mapping',
'get_dynamic_quant_module_mappings',
'register_qat_module_mapping',
'get_qat_module_mappings',
'get_qconfig_propagation_list',
'get_compare_output_module_list',
'register_quantized_operator_mapping', 'get_quantized_operator',
'register_fuser_method', 'get_fuser_method',
'get_default_static_quant_module_mappings', 'get_static_quant_module_class',
'get_default_dynamic_quant_module_mappings',
'get_default_qat_module_mappings',
'get_default_qconfig_propagation_list',
'get_default_compare_output_module_list',
'get_quantized_operator',
'get_fuser_method',
# Sub functions for `prepare` and `swap_module`
'propagate_qconfig_', 'add_quant_dequant', 'add_observer_', 'swap_module',
'default_eval_fn', 'get_observer_dict',
Expand Down
6 changes: 3 additions & 3 deletions torch/quantization/_numeric_suite.py
Expand Up @@ -7,7 +7,7 @@
from typing import Dict

from .quantization_mappings import (
get_compare_output_module_list,
get_default_compare_output_module_list,
)

NON_LEAF_MODULE_TO_ADD_OBSERVER_ALLOW_LIST = {
Expand Down Expand Up @@ -412,7 +412,7 @@ def prepare_model_outputs(
"""
torch._C._log_api_usage_once("quantization_api._numeric_suite.prepare_model_outputs")
if allow_list is None:
allow_list = get_compare_output_module_list()
allow_list = get_default_compare_output_module_list()

qconfig_debug = torch.quantization.QConfig(activation=Logger, weight=None)
float_module.qconfig = qconfig_debug
Expand Down Expand Up @@ -459,7 +459,7 @@ def compare_model_outputs(
"""
torch._C._log_api_usage_once("quantization_api._numeric_suite.compare_model_outputs")
if allow_list is None:
allow_list = get_compare_output_module_list()
allow_list = get_default_compare_output_module_list()
prepare_model_outputs(float_model, q_model, Logger, allow_list)
float_model(*data)
q_model(*data)
Expand Down
12 changes: 3 additions & 9 deletions torch/quantization/fuser_method_mappings.py
Expand Up @@ -72,7 +72,7 @@ def fuse_conv_bn_relu(conv, bn, relu):
else:
raise NotImplementedError("Cannot fuse eval modules: {}".format((conv, bn, relu)))

OP_LIST_TO_FUSER_METHOD : Dict[Tuple, Union[nn.Sequential, Callable]] = {
DEFAULT_OP_LIST_TO_FUSER_METHOD : Dict[Tuple, Union[nn.Sequential, Callable]] = {
(nn.Conv1d, nn.BatchNorm1d): fuse_conv_bn,
(nn.Conv1d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
(nn.Conv2d, nn.BatchNorm2d): fuse_conv_bn,
Expand All @@ -87,15 +87,9 @@ def fuse_conv_bn_relu(conv, bn, relu):
(nn.BatchNorm3d, nn.ReLU): nni.BNReLU3d,
}

def register_fuser_method(op_list, fuser_method):
''' Register a fuser method for a tuple of ops, will be called
during fusion step
'''
assert isinstance(op_list, tuple), 'op list must be a tuple'
OP_LIST_TO_FUSER_METHOD[op_list] = fuser_method

# TODO: remove
def get_fuser_method(op_list):
''' Get fuser method for the given list of module types,
return None if fuser method does not exist
'''
return OP_LIST_TO_FUSER_METHOD.get(op_list, None)
return DEFAULT_OP_LIST_TO_FUSER_METHOD.get(op_list, None)
4 changes: 2 additions & 2 deletions torch/quantization/fx/fuse.py
Expand Up @@ -7,7 +7,7 @@

from .pattern_utils import (
is_match,
get_fusion_patterns,
get_default_fusion_patterns,
)

from .fusion_patterns import * # noqa: F401
Expand All @@ -21,7 +21,7 @@ def fuse(self, model, inplace=False):
input_graph = model.graph
self.modules = dict(input_root.named_modules())

fusion_patterns = get_fusion_patterns()
fusion_patterns = get_default_fusion_patterns()
# find fusion
fusion_pairs = self._find_matches(input_root, input_graph, fusion_patterns)
self.fused_graph = Graph()
Expand Down
16 changes: 8 additions & 8 deletions torch/quantization/fx/pattern_utils.py
Expand Up @@ -3,27 +3,27 @@
from collections import OrderedDict

# pattern for conv bn fusion
FUSION_PATTERNS = OrderedDict()
DEFAULT_FUSION_PATTERNS = OrderedDict()
def register_fusion_pattern(pattern):
def insert(fn):
FUSION_PATTERNS[pattern] = fn
DEFAULT_FUSION_PATTERNS[pattern] = fn
return fn
return insert

def get_fusion_patterns():
return FUSION_PATTERNS
def get_default_fusion_patterns():
return DEFAULT_FUSION_PATTERNS

QUANTIZATION_PATTERNS = OrderedDict()
DEFAULT_QUANTIZATION_PATTERNS = OrderedDict()
# Register pattern for both static quantization and qat
def register_quant_pattern(pattern):
def insert(fn):
QUANTIZATION_PATTERNS[pattern] = fn
DEFAULT_QUANTIZATION_PATTERNS[pattern] = fn
return fn
return insert

# Get patterns for both static quantization and qat
def get_quant_patterns():
return QUANTIZATION_PATTERNS
def get_default_quant_patterns():
return DEFAULT_QUANTIZATION_PATTERNS

# Example use of register pattern function:
# @register_fusion_pattern(torch.nn.ReLU, (torch.nn.BatchNorm2d, torch.nn.Conv2d)))
Expand Down
8 changes: 4 additions & 4 deletions torch/quantization/fx/quantize.py
Expand Up @@ -16,14 +16,14 @@
)

from ..quantization_mappings import (
get_qat_module_mappings,
get_default_qat_module_mappings,
)

from ..quantize import _remove_qconfig

from .pattern_utils import (
is_match,
get_quant_patterns,
get_default_quant_patterns,
)

from .standalone_module import (
Expand Down Expand Up @@ -231,7 +231,7 @@ def __init__(self):


def _qat_swap_modules(self, root):
convert(root, mapping=get_qat_module_mappings(), inplace=True, remove_qconfig=False)
convert(root, mapping=get_default_qat_module_mappings(), inplace=True, remove_qconfig=False)

def _generate_qconfig_map(self,
root,
Expand Down Expand Up @@ -324,7 +324,7 @@ def _prepare(self, model, qconfig_dict, inplace, prepare_custom_config_dict, is_
prepare_custom_config_dict = {}
if not inplace:
model = copy.deepcopy(model)
self.patterns = get_quant_patterns()
self.patterns = get_default_quant_patterns()

flattened_qconfig_dict = get_flattened_qconfig_dict(qconfig_dict)
# TODO: support regex as well
Expand Down
99 changes: 31 additions & 68 deletions torch/quantization/quantization_mappings.py
Expand Up @@ -11,8 +11,8 @@

from .stubs import QuantStub, DeQuantStub

# Map for swapping float module to quantized ones
STATIC_QUANT_MODULE_MAPPINGS = {
# Default map for swapping float module to quantized ones
DEFAULT_STATIC_QUANT_MODULE_MAPPINGS = {
QuantStub: nnq.Quantize,
DeQuantStub: nnq.DeQuantize,
nn.BatchNorm2d: nnq.BatchNorm2d,
Expand Down Expand Up @@ -53,8 +53,8 @@
nnqat.Conv2d: nnq.Conv2d,
}

# Map for swapping float module to qat modules
QAT_MODULE_MAPPINGS = {
# Default map for swapping float module to qat modules
DEFAULT_QAT_MODULE_MAPPINGS = {
nn.Conv2d: nnqat.Conv2d,
nn.Linear: nnqat.Linear,
# Intrinsic modules:
Expand All @@ -64,8 +64,8 @@
nni.LinearReLU: nniqat.LinearReLU
}

# Map for swapping dynamic modules
DYNAMIC_QUANT_MODULE_MAPPINGS = {
# Default map for swapping dynamic modules
DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS = {
nn.GRUCell: nnqd.GRUCell,
nn.Linear: nnqd.Linear,
nn.LSTM: nnqd.LSTM,
Expand All @@ -81,111 +81,74 @@
nn.Sequential,
}

# mapping from floating point function or torch ops to quantized ops
FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS = {
# Default mapping from floating point function or torch ops to quantized ops
# TODO: merge with default static mapping
DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS = {
F.elu: torch._ops.ops.quantized.elu,
F.hardswish: torch._ops.ops.quantized.hardswish,
F.instance_norm: torch._ops.ops.quantized.instance_norm,
F.layer_norm: torch._ops.ops.quantized.layer_norm,
F.leaky_relu: torch._ops.ops.quantized.leaky_relu,
}

def register_static_quant_module_mapping(
float_source_module_class, static_quant_target_module_class):
''' Register a mapping from `float_source__module_class` to `static_quant_target_module_class`
`static_quant_target_module_class` must have from_float defined as a class method
The mapping is used in the convert step of post training static quantization to
convert a float module to a statically quantized module.
'''
assert hasattr(static_quant_target_module_class, 'from_float'), 'from_float must be defined' + \
' in quantized module class'
STATIC_QUANT_MODULE_MAPPINGS[float_source_module_class] = static_quant_target_module_class

def get_static_quant_module_mappings():
def get_default_static_quant_module_mappings():
''' Get module mapping for post training static quantization
'''
return STATIC_QUANT_MODULE_MAPPINGS
return DEFAULT_STATIC_QUANT_MODULE_MAPPINGS

def get_static_quant_module_class(float_module_class):
''' Get the statically quantized module class corresponding to
the floating point module class
'''
static_quant_module_class = STATIC_QUANT_MODULE_MAPPINGS.get(float_module_class, None)
static_quant_module_class = DEFAULT_STATIC_QUANT_MODULE_MAPPINGS.get(float_module_class, None)
assert static_quant_module_class is not None, \
'Floating point module class {}'.format(float_module_class) + \
' does not have a corresponding quantized module class'
return static_quant_module_class

def register_qat_module_mapping(float_source_module_class, qat_target_module_class):
'''Register a mapping from `float_source_module_class` to `qat_target_module_class`,
`qat_target_module_class` must have from_float defined as a class method
This mapping is used in prepare step of quantization aware training to swap
a float module to a qat module.
'''
assert hasattr(qat_target_module_class, 'from_float'), 'from_float must be defined' + \
' in qat module class'
QAT_MODULE_MAPPINGS[float_source_module_class] = qat_target_module_class

def get_qat_module_mappings():
''' Get module mapping for quantization aware training
def get_default_qat_module_mappings():
''' Get default module mapping for quantization aware training
'''
return QAT_MODULE_MAPPINGS
return DEFAULT_QAT_MODULE_MAPPINGS

def register_dynamic_quant_module_class(float_source_module_class, dynamic_quant_target_module_class):
''' Register a mapping from `float_source_module_class` to `dynamic_quant_target_module_class`,
`dynamic_quant_target_module_class` must have from_float defined as a class method
This mapping is used in convert step of post training dynamic
quantization to swap a float module to a dynamically quantized
module.
'''
assert hasattr(dynamic_quant_target_module_class, 'from_float'), 'from_float must be defined' + \
' in dynamically quantized module type'
DYNAMIC_QUANT_MODULE_MAPPINGS[float_source_module_class] = dynamic_quant_target_module_class

def get_dynamic_quant_module_mappings():
def get_default_dynamic_quant_module_mappings():
''' Get module mapping for post training dynamic quantization
'''
return DYNAMIC_QUANT_MODULE_MAPPINGS
return DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS

def get_qconfig_propagation_list():
''' Get the list of module types that we'll attach qconfig
def get_default_qconfig_propagation_list():
''' Get the default list of module types that we'll attach qconfig
attribute to in prepare
'''
QCONFIG_PROPAGATE_MODULE_CLASS_LIST = (
(set(STATIC_QUANT_MODULE_MAPPINGS.keys()) |
set(QAT_MODULE_MAPPINGS.keys()) |
set(DYNAMIC_QUANT_MODULE_MAPPINGS.keys()) |
(set(DEFAULT_STATIC_QUANT_MODULE_MAPPINGS.keys()) |
set(DEFAULT_QAT_MODULE_MAPPINGS.keys()) |
set(DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS.keys()) |
_INCLUDE_QCONFIG_PROPAGATE_LIST) -
_EXCLUDE_QCONFIG_PROPAGATE_LIST
)
return QCONFIG_PROPAGATE_MODULE_CLASS_LIST

def get_compare_output_module_list():
def get_default_compare_output_module_list():
''' Get list of module class types that we will record output
in numeric suite
'''
NUMERIC_SUITE_COMPARE_MODEL_OUTPUT_MODULE_LIST = (
set(STATIC_QUANT_MODULE_MAPPINGS.values())
| set(QAT_MODULE_MAPPINGS.values())
| set(DYNAMIC_QUANT_MODULE_MAPPINGS.values())
| set(STATIC_QUANT_MODULE_MAPPINGS.keys())
| set(QAT_MODULE_MAPPINGS.keys())
| set(DYNAMIC_QUANT_MODULE_MAPPINGS.keys())
set(DEFAULT_STATIC_QUANT_MODULE_MAPPINGS.values())
| set(DEFAULT_QAT_MODULE_MAPPINGS.values())
| set(DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS.values())
| set(DEFAULT_STATIC_QUANT_MODULE_MAPPINGS.keys())
| set(DEFAULT_QAT_MODULE_MAPPINGS.keys())
| set(DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS.keys())
| _INCLUDE_QCONFIG_PROPAGATE_LIST
) - _EXCLUDE_QCONFIG_PROPAGATE_LIST
return NUMERIC_SUITE_COMPARE_MODEL_OUTPUT_MODULE_LIST

def register_quantized_operator_mapping(float_op, quantized_op):
''' Register a mapping from `floating_point_op` (torch or functional) to `quantized_op`
This is used in convert step of fx based graph mode quantization
to convert a float op to quantized op.
'''
FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS[float_op] = quantized_op

# TODO: merge with get_static_quant_module_class
def get_quantized_operator(float_op):
''' Get the quantized operator corresponding to the float operator
'''
quantized_op = FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS.get(float_op, None)
quantized_op = DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS.get(float_op, None)
assert quantized_op is not None, \
'Operator {} does not have corresponding quantized op'.format(float_op)
return quantized_op