[quant][refactor] Remove register api and rename get_*_mapping to get_default_*_mapping

Summary:
We plan to pass the mappings around instead of using a global registration API, so that the mappings stay local to the transformation the user is performing.
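
To make the direction concrete, here is a minimal sketch (not part of this commit) of the intended call pattern: copy the default mapping, edit the copy locally, and pass it to convert, instead of registering the pair into a global table. MyFloatModule, MyQuantizedModule and the toy model are hypothetical; the mapping keyword on convert is the same one used by _qat_swap_modules later in this diff.

import copy
import torch
import torch.nn as nn
from torch.quantization import convert, get_default_static_quant_module_mappings

class MyFloatModule(nn.Module):       # hypothetical user-defined float module
    def forward(self, x):
        return x

class MyQuantizedModule(nn.Module):   # hypothetical quantized counterpart
    @classmethod
    def from_float(cls, mod):         # convert() builds the swapped module via from_float
        return cls()
    def forward(self, x):
        return x

# Copy the default table and extend the copy, leaving the global default untouched.
mapping = copy.copy(get_default_static_quant_module_mappings())
mapping[MyFloatModule] = MyQuantizedModule

model = nn.Sequential(MyFloatModule())
model[0].qconfig = torch.quantization.default_qconfig  # convert only swaps modules that carry a qconfig
quantized = convert(model, mapping=mapping, inplace=False)
print(type(quantized[0]).__name__)    # expected: MyQuantizedModule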

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
jerryzh168 committed Oct 14, 2020
1 parent 37c6be0 commit ce57dad
Showing 9 changed files with 74 additions and 120 deletions.
17 changes: 7 additions & 10 deletions torch/quantization/__init__.py
@@ -30,16 +30,13 @@ def default_eval_fn(model, calib_data):
'prepare_fx', 'prepare_dynamic_fx', 'convert_fx',
'QuantType', # quantization type
# custom module APIs
'register_static_quant_module_mapping',
'get_static_quant_module_mappings', 'get_static_quant_module_class',
'register_dynamic_quant_module_mapping',
'get_dynamic_quant_module_mappings',
'register_qat_module_mapping',
'get_qat_module_mappings',
'get_qconfig_propagation_list',
'get_compare_output_module_list',
'register_quantized_operator_mapping', 'get_quantized_operator',
'register_fuser_method', 'get_fuser_method',
'get_default_static_quant_module_mappings', 'get_static_quant_module_class',
'get_default_dynamic_quant_module_mappings',
'get_default_qat_module_mappings',
'get_default_qconfig_propagation_list',
'get_default_compare_output_module_list',
'get_quantized_operator',
'get_fuser_method',
# Sub functions for `prepare` and `swap_module`
'propagate_qconfig_', 'add_quant_dequant', 'add_observer_', 'swap_module',
'default_eval_fn', 'get_observer_dict',
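After this rename, callers import the get_default_* getters instead of the removed register_* helpers; a short sketch, assuming torch.quantization re-exports everything listed in the updated __all__ above:

from torch.quantization import (
    get_default_static_quant_module_mappings,
    get_default_dynamic_quant_module_mappings,
    get_default_qat_module_mappings,
    get_default_qconfig_propagation_list,
)

# Each getter returns the module-level default table; callers copy and edit it
# locally rather than registering into it.
print(len(get_default_qconfig_propagation_list()))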
6 changes: 3 additions & 3 deletions torch/quantization/_numeric_suite.py
@@ -6,7 +6,7 @@
from torch.quantization import prepare

from .quantization_mappings import (
get_compare_output_module_list,
get_default_compare_output_module_list,
)

NON_LEAF_MODULE_TO_ADD_OBSERVER_ALLOW_LIST = {
@@ -405,7 +405,7 @@ def prepare_model_outputs(
allow_list: list of module types to attach logger
"""
if allow_list is None:
allow_list = get_compare_output_module_list()
allow_list = get_default_compare_output_module_list()

qconfig_debug = torch.quantization.QConfig(activation=Logger, weight=None)
float_module.qconfig = qconfig_debug
@@ -451,7 +451,7 @@ def compare_model_outputs(
containing the matching float and quantized activations
"""
if allow_list is None:
allow_list = get_compare_output_module_list()
allow_list = get_default_compare_output_module_list()
prepare_model_outputs(float_model, q_model, Logger, allow_list)
float_model(*data)
q_model(*data)
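For the numeric suite, the default comparison allow list can now be extended per call rather than via a global registry; a minimal sketch, where float_model and q_model are assumed to exist and OutputLogger is assumed to be the stock logger class in _numeric_suite:

import torch.nn as nn
from torch.quantization._numeric_suite import OutputLogger, prepare_model_outputs
from torch.quantization.quantization_mappings import get_default_compare_output_module_list

# The default is a set of module types, so extending a per-call copy is cheap.
allow_list = get_default_compare_output_module_list() | {nn.Sigmoid}

# Mirrors the call made inside compare_model_outputs above; float_model and
# q_model are placeholders for a float model and its quantized counterpart.
# prepare_model_outputs(float_model, q_model, OutputLogger, allow_list)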
12 changes: 3 additions & 9 deletions torch/quantization/fuser_method_mappings.py
@@ -72,7 +72,7 @@ def fuse_conv_bn_relu(conv, bn, relu):
else:
raise NotImplementedError("Cannot fuse eval modules: {}".format((conv, bn, relu)))

OP_LIST_TO_FUSER_METHOD : Dict[Tuple, Union[nn.Sequential, Callable]] = {
DEFAULT_OP_LIST_TO_FUSER_METHOD : Dict[Tuple, Union[nn.Sequential, Callable]] = {
(nn.Conv1d, nn.BatchNorm1d): fuse_conv_bn,
(nn.Conv1d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
(nn.Conv2d, nn.BatchNorm2d): fuse_conv_bn,
@@ -87,15 +87,9 @@ def fuse_conv_bn_relu(conv, bn, relu):
(nn.BatchNorm3d, nn.ReLU): nni.BNReLU3d,
}

def register_fuser_method(op_list, fuser_method):
''' Register a fuser method for a tuple of ops, will be called
during fusion step
'''
assert isinstance(op_list, tuple), 'op list must be a tuple'
OP_LIST_TO_FUSER_METHOD[op_list] = fuser_method

# TODO: remove
def get_fuser_method(op_list):
''' Get fuser method for the given list of module types,
return None if fuser method does not exist
'''
return OP_LIST_TO_FUSER_METHOD.get(op_list, None)
return DEFAULT_OP_LIST_TO_FUSER_METHOD.get(op_list, None)
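
For reference, a short sketch of how the renamed table is consulted through get_fuser_method; the lookup behavior itself is unchanged by this commit:

import torch.nn as nn
from torch.quantization.fuser_method_mappings import get_fuser_method

# Keys are tuples of module types, matching DEFAULT_OP_LIST_TO_FUSER_METHOD above.
print(get_fuser_method((nn.Conv2d, nn.BatchNorm2d)))  # the fuse_conv_bn helper
print(get_fuser_method((nn.Sigmoid, nn.Tanh)))        # None: no fuser method registered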
4 changes: 2 additions & 2 deletions torch/quantization/fx/fuse.py
@@ -7,7 +7,7 @@

from .pattern_utils import (
is_match,
get_fusion_patterns,
get_default_fusion_patterns,
)

from .fusion_patterns import * # noqa: F401
@@ -21,7 +21,7 @@ def fuse(self, model, inplace=False):
input_graph = model.graph
self.modules = dict(input_root.named_modules())

fusion_patterns = get_fusion_patterns()
fusion_patterns = get_default_fusion_patterns()
# find fusion
fusion_pairs = self._find_matches(input_root, input_graph, fusion_patterns)
self.fused_graph = Graph()
16 changes: 8 additions & 8 deletions torch/quantization/fx/pattern_utils.py
@@ -3,27 +3,27 @@
from collections import OrderedDict

# pattern for conv bn fusion
FUSION_PATTERNS = OrderedDict()
DEFAULT_FUSION_PATTERNS = OrderedDict()
def register_fusion_pattern(pattern):
def insert(fn):
FUSION_PATTERNS[pattern] = fn
DEFAULT_FUSION_PATTERNS[pattern] = fn
return fn
return insert

def get_fusion_patterns():
return FUSION_PATTERNS
def get_default_fusion_patterns():
return DEFAULT_FUSION_PATTERNS

QUANTIZATION_PATTERNS = OrderedDict()
DEFAULT_QUANTIZATION_PATTERNS = OrderedDict()
# Register pattern for both static quantization and qat
def register_quant_pattern(pattern):
def insert(fn):
QUANTIZATION_PATTERNS[pattern] = fn
DEFAULT_QUANTIZATION_PATTERNS[pattern] = fn
return fn
return insert

# Get patterns for both static quantization and qat
def get_quant_patterns():
return QUANTIZATION_PATTERNS
def get_default_quant_patterns():
return DEFAULT_QUANTIZATION_PATTERNS

# Example use of register pattern function:
# @register_fusion_pattern((torch.nn.ReLU, (torch.nn.BatchNorm2d, torch.nn.Conv2d)))
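
A runnable sketch of the decorator together with the renamed getter; DummyHandler and the (Sigmoid, Conv2d) pattern are made up for illustration and would not normally be registered into the shared default table:

import torch
from torch.quantization.fx.pattern_utils import (
    register_fusion_pattern,
    get_default_fusion_patterns,
)

# Patterns are keyed with the outermost op first, as in the comment above.
@register_fusion_pattern((torch.nn.Sigmoid, torch.nn.Conv2d))
class DummyHandler:            # hypothetical handler; real ones live in fusion_patterns.py
    pass

print((torch.nn.Sigmoid, torch.nn.Conv2d) in get_default_fusion_patterns())  # True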
8 changes: 4 additions & 4 deletions torch/quantization/fx/quantize.py
@@ -16,14 +16,14 @@
)

from ..quantization_mappings import (
get_qat_module_mappings,
get_default_qat_module_mappings,
)

from ..quantize import _remove_qconfig

from .pattern_utils import (
is_match,
get_quant_patterns,
get_default_quant_patterns,
)

from .standalone_module import (
@@ -231,7 +231,7 @@ def __init__(self):


def _qat_swap_modules(self, root):
convert(root, mapping=get_qat_module_mappings(), inplace=True, remove_qconfig=False)
convert(root, mapping=get_default_qat_module_mappings(), inplace=True, remove_qconfig=False)

def _generate_qconfig_map(self,
root,
@@ -324,7 +324,7 @@ def _prepare(self, model, qconfig_dict, inplace, prepare_custom_config_dict, is_
prepare_custom_config_dict = {}
if not inplace:
model = copy.deepcopy(model)
self.patterns = get_quant_patterns()
self.patterns = get_default_quant_patterns()

flattened_qconfig_dict = get_flattened_qconfig_dict(qconfig_dict)
# TODO: support regex as well
97 changes: 30 additions & 67 deletions torch/quantization/quantization_mappings.py
@@ -11,8 +11,8 @@

from .stubs import QuantStub, DeQuantStub

# Map for swapping float module to quantized ones
STATIC_QUANT_MODULE_MAPPINGS = {
# Default map for swapping float module to quantized ones
DEFAULT_STATIC_QUANT_MODULE_MAPPINGS = {
QuantStub: nnq.Quantize,
DeQuantStub: nnq.DeQuantize,
nn.BatchNorm2d: nnq.BatchNorm2d,
@@ -53,8 +53,8 @@
nnqat.Conv2d: nnq.Conv2d,
}

# Map for swapping float module to qat modules
QAT_MODULE_MAPPINGS = {
# Default map for swapping float module to qat modules
DEFAULT_QAT_MODULE_MAPPINGS = {
nn.Conv2d: nnqat.Conv2d,
nn.Linear: nnqat.Linear,
# Intrinsic modules:
@@ -64,7 +64,7 @@
nni.LinearReLU: nniqat.LinearReLU
}

# Map for swapping dynamic modules
# Default map for swapping dynamic modules
DYNAMIC_QUANT_MODULE_MAPPINGS = {
DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS = {
nn.GRUCell: nnqd.GRUCell,
nn.Linear: nnqd.Linear,
@@ -81,111 +81,74 @@
nn.Sequential,
}

# mapping from floating point function or torch ops to quantized ops
FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS = {
# Default mapping from floating point function or torch ops to quantized ops
# TODO: merge with default static mapping
DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS = {
F.elu: torch._ops.ops.quantized.elu,
F.hardswish: torch._ops.ops.quantized.hardswish,
F.instance_norm: torch._ops.ops.quantized.instance_norm,
F.layer_norm: torch._ops.ops.quantized.layer_norm,
F.leaky_relu: torch._ops.ops.quantized.leaky_relu,
}

def register_static_quant_module_mapping(
float_source_module_class, static_quant_target_module_class):
''' Register a mapping from `float_source__module_class` to `static_quant_target_module_class`
`static_quant_target_module_class` must have from_float defined as a class method
The mapping is used in the convert step of post training static quantization to
convert a float module to a statically quantized module.
'''
assert hasattr(static_quant_target_module_class, 'from_float'), 'from_float must be defined' + \
' in quantized module class'
STATIC_QUANT_MODULE_MAPPINGS[float_source_module_class] = static_quant_target_module_class

def get_static_quant_module_mappings():
def get_default_static_quant_module_mappings():
''' Get module mapping for post training static quantization
'''
return STATIC_QUANT_MODULE_MAPPINGS
return DEFAULT_STATIC_QUANT_MODULE_MAPPINGS

def get_static_quant_module_class(float_module_class):
''' Get the statically quantized module class corresponding to
the floating point module class
'''
static_quant_module_class = STATIC_QUANT_MODULE_MAPPINGS.get(float_module_class, None)
static_quant_module_class = DEFAULT_STATIC_QUANT_MODULE_MAPPINGS.get(float_module_class, None)
assert static_quant_module_class is not None, \
'Floating point module class {}'.format(float_module_class) + \
' does not have a corresponding quantized module class'
return static_quant_module_class
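
A one-line sketch of this lookup, using an entry visible in the default static mapping above:

import torch.nn as nn
from torch.quantization.quantization_mappings import get_static_quant_module_class

# nn.BatchNorm2d maps to nnq.BatchNorm2d in DEFAULT_STATIC_QUANT_MODULE_MAPPINGS.
print(get_static_quant_module_class(nn.BatchNorm2d))  # expected: torch.nn.quantized.BatchNorm2d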

def register_qat_module_mapping(float_source_module_class, qat_target_module_class):
'''Register a mapping from `float_source_module_class` to `qat_target_module_class`,
`qat_target_module_class` must have from_float defined as a class method
This mapping is used in prepare step of quantization aware training to swap
a float module to a qat module.
'''
assert hasattr(qat_target_module_class, 'from_float'), 'from_float must be defined' + \
' in qat module class'
QAT_MODULE_MAPPINGS[float_source_module_class] = qat_target_module_class

def get_qat_module_mappings():
''' Get module mapping for quantization aware training
def get_default_qat_module_mappings():
''' Get default module mapping for quantization aware training
'''
return QAT_MODULE_MAPPINGS
return DEFAULT_QAT_MODULE_MAPPINGS
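
A sketch of passing the QAT mapping explicitly at prepare time instead of relying on a registration side effect; it assumes prepare_qat accepts a mapping argument (which defaults to this table) and uses a stock qconfig:

import copy
import torch
import torch.nn as nn
from torch.quantization import prepare_qat, get_default_qat_module_mappings

model = nn.Sequential(nn.Linear(4, 4))
model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')

# Copy the default table; a caller-local override could be added to the copy here.
qat_mapping = copy.copy(get_default_qat_module_mappings())
prepare_qat(model, mapping=qat_mapping, inplace=True)
print(type(model[0]))  # expected: torch.nn.qat.Linear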

def register_dynamic_quant_module_class(float_source_module_class, dynamic_quant_target_module_class):
''' Register a mapping from `float_source_module_class` to `dynamic_quant_target_module_class`,
`dynamic_quant_target_module_class` must have from_float defined as a class method
This mapping is used in convert step of post training dynamic
quantization to swap a float module to a dynamically quantized
module.
'''
assert hasattr(dynamic_quant_target_module_class, 'from_float'), 'from_float must be defined' + \
' in dynamically quantized module type'
DYNAMIC_QUANT_MODULE_MAPPINGS[float_source_module_class] = dynamic_quant_target_module_class

def get_dynamic_quant_module_mappings():
def get_default_dynamic_quant_module_mappings():
''' Get module mapping for post training dynamic quantization
'''
return DYNAMIC_QUANT_MODULE_MAPPINGS
return DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS
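
Similarly for dynamic quantization, the mapping can be supplied per call; this sketch assumes quantize_dynamic exposes a mapping argument that defaults to this table:

import torch
import torch.nn as nn
from torch.quantization import quantize_dynamic, get_default_dynamic_quant_module_mappings

model = nn.Sequential(nn.Linear(8, 8))
# Pass an explicit (copied) table instead of relying on a globally registered one.
dyn_mapping = dict(get_default_dynamic_quant_module_mappings())
qmodel = quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8, mapping=dyn_mapping)
print(type(qmodel[0]))  # expected: torch.nn.quantized.dynamic.Linear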

def get_qconfig_propagation_list():
''' Get the list of module types that we'll attach qconfig
def get_default_qconfig_propagation_list():
''' Get the default list of module types that we'll attach qconfig
attribute to in prepare
'''
QCONFIG_PROPAGATE_MODULE_CLASS_LIST = (
(set(STATIC_QUANT_MODULE_MAPPINGS.keys()) |
set(QAT_MODULE_MAPPINGS.keys()) |
set(DYNAMIC_QUANT_MODULE_MAPPINGS.keys()) |
(set(DEFAULT_STATIC_QUANT_MODULE_MAPPINGS.keys()) |
set(DEFAULT_QAT_MODULE_MAPPINGS.keys()) |
set(DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS.keys()) |
_INCLUDE_QCONFIG_PROPAGATE_LIST) -
_EXCLUDE_QCONFIG_PROPAGATE_LIST
)
return QCONFIG_PROPAGATE_MODULE_CLASS_LIST
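
The propagation list is simply the union of the default mapping keys computed above; a quick membership check as a sketch:

import torch.nn as nn
from torch.quantization.quantization_mappings import get_default_qconfig_propagation_list

propagate_list = get_default_qconfig_propagation_list()
# nn.Conv2d is a key of the default QAT mapping, so it is part of the union.
print(nn.Conv2d in propagate_list)  # expected: True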

def get_compare_output_module_list():
def get_default_compare_output_module_list():
''' Get list of module class types that we will record output
in numeric suite
'''
NUMERIC_SUITE_COMPARE_MODEL_OUTPUT_MODULE_LIST = (
set(STATIC_QUANT_MODULE_MAPPINGS.values())
| set(QAT_MODULE_MAPPINGS.values())
| set(DYNAMIC_QUANT_MODULE_MAPPINGS.values())
| set(STATIC_QUANT_MODULE_MAPPINGS.keys())
| set(QAT_MODULE_MAPPINGS.keys())
| set(DYNAMIC_QUANT_MODULE_MAPPINGS.keys())
set(DEFAULT_STATIC_QUANT_MODULE_MAPPINGS.values())
| set(DEFAULT_QAT_MODULE_MAPPINGS.values())
| set(DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS.values())
| set(DEFAULT_STATIC_QUANT_MODULE_MAPPINGS.keys())
| set(DEFAULT_QAT_MODULE_MAPPINGS.keys())
| set(DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS.keys())
| _INCLUDE_QCONFIG_PROPAGATE_LIST
) - _EXCLUDE_QCONFIG_PROPAGATE_LIST
return NUMERIC_SUITE_COMPARE_MODEL_OUTPUT_MODULE_LIST

def register_quantized_operator_mapping(float_op, quantized_op):
''' Register a mapping from `floating_point_op` (torch or functional) to `quantized_op`
This is used in convert step of fx based graph mode quantization
to convert a float op to quantized op.
'''
FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS[float_op] = quantized_op

# TODO: merge with get_static_quant_module_class
def get_quantized_operator(float_op):
''' Get the quantized operator corresponding to the float operator
'''
quantized_op = FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS.get(float_op, None)
quantized_op = DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS.get(float_op, None)
assert quantized_op is not None, \
'Operator {} does not have corresponding quantized op'.format(float_op)
return quantized_op
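
And the functional-op lookup, using an entry visible in the default operator mapping above:

import torch.nn.functional as F
from torch.quantization.quantization_mappings import get_quantized_operator

# F.hardswish maps to torch.ops.quantized.hardswish in the default table.
print(get_quantized_operator(F.hardswish))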
