From ce57dadb966876187459df657951ddf14e0d9e73 Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Wed, 14 Oct 2020 13:49:49 -0700
Subject: [PATCH] [quant][refactor] Remove register api and rename get_*_mapping to get_default_*_mapping

Summary: We plan to pass around the mappings instead of using a global registration API, to keep the mappings local to the transformations the user is performing.

Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:

[ghstack-poisoned]
---
 torch/quantization/__init__.py | 17 ++--
 torch/quantization/_numeric_suite.py | 6 +-
 torch/quantization/fuser_method_mappings.py | 12 +--
 torch/quantization/fx/fuse.py | 4 +-
 torch/quantization/fx/pattern_utils.py | 16 +--
 torch/quantization/fx/quantize.py | 8 +-
 torch/quantization/quantization_mappings.py | 97 ++++++-------------
 torch/quantization/quantize.py | 22 ++---
 .../testing/_internal/common_quantization.py | 12 +--
 9 files changed, 74 insertions(+), 120 deletions(-)

diff --git a/torch/quantization/__init__.py b/torch/quantization/__init__.py
index ed908ddf85c3..f56c3821e872 100644
--- a/torch/quantization/__init__.py
+++ b/torch/quantization/__init__.py
@@ -30,16 +30,13 @@ def default_eval_fn(model, calib_data): 'prepare_fx', 'prepare_dynamic_fx', 'convert_fx', 'QuantType', # quantization type # custom module APIs - 'register_static_quant_module_mapping', - 'get_static_quant_module_mappings', 'get_static_quant_module_class', - 'register_dynamic_quant_module_mapping', - 'get_dynamic_quant_module_mappings', - 'register_qat_module_mapping', - 'get_qat_module_mappings', - 'get_qconfig_propagation_list', - 'get_compare_output_module_list', - 'register_quantized_operator_mapping', 'get_quantized_operator', - 'register_fuser_method', 'get_fuser_method', + 'get_default_static_quant_module_mappings', 'get_static_quant_module_class', + 'get_default_dynamic_quant_module_mappings', + 'get_default_qat_module_mappings', + 'get_default_qconfig_propagation_list', + 'get_default_compare_output_module_list', + 'get_quantized_operator', + 'get_fuser_method', # Sub functions for `prepare` and `swap_module` 'propagate_qconfig_', 'add_quant_dequant', 'add_observer_', 'swap_module', 'default_eval_fn', 'get_observer_dict',
diff --git a/torch/quantization/_numeric_suite.py b/torch/quantization/_numeric_suite.py
index 01703269604a..79f1276051c2 100644
--- a/torch/quantization/_numeric_suite.py
+++ b/torch/quantization/_numeric_suite.py
@@ -6,7 +6,7 @@ from torch.quantization import prepare from .quantization_mappings import ( - get_compare_output_module_list, + get_default_compare_output_module_list, ) NON_LEAF_MODULE_TO_ADD_OBSERVER_ALLOW_LIST = {
@@ -405,7 +405,7 @@ def prepare_model_outputs( allow_list: list of module types to attach logger """ if allow_list is None: - allow_list = get_compare_output_module_list() + allow_list = get_default_compare_output_module_list() qconfig_debug = torch.quantization.QConfig(activation=Logger, weight=None) float_module.qconfig = qconfig_debug
@@ -451,7 +451,7 @@ def compare_model_outputs( containing the matching float and quantized activations """ if allow_list is None: - allow_list = get_compare_output_module_list() + allow_list = get_default_compare_output_module_list() prepare_model_outputs(float_model, q_model, Logger, allow_list) float_model(*data) q_model(*data)
diff --git a/torch/quantization/fuser_method_mappings.py b/torch/quantization/fuser_method_mappings.py
index 72ad5a7bcc71..63acc849f56a 100644
--- a/torch/quantization/fuser_method_mappings.py
+++ b/torch/quantization/fuser_method_mappings.py
@@ -72,7 
+72,7 @@ def fuse_conv_bn_relu(conv, bn, relu): else: raise NotImplementedError("Cannot fuse eval modules: {}".format((conv, bn, relu))) -OP_LIST_TO_FUSER_METHOD : Dict[Tuple, Union[nn.Sequential, Callable]] = { +DEFAULT_OP_LIST_TO_FUSER_METHOD : Dict[Tuple, Union[nn.Sequential, Callable]] = { (nn.Conv1d, nn.BatchNorm1d): fuse_conv_bn, (nn.Conv1d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu, (nn.Conv2d, nn.BatchNorm2d): fuse_conv_bn, @@ -87,15 +87,9 @@ def fuse_conv_bn_relu(conv, bn, relu): (nn.BatchNorm3d, nn.ReLU): nni.BNReLU3d, } -def register_fuser_method(op_list, fuser_method): - ''' Register a fuser method for a tuple of ops, will be called - during fusion step - ''' - assert isinstance(op_list, tuple), 'op list must be a tuple' - OP_LIST_TO_FUSER_METHOD[op_list] = fuser_method - +# TODO: remove def get_fuser_method(op_list): ''' Get fuser method for the given list of module types, return None if fuser method does not exist ''' - return OP_LIST_TO_FUSER_METHOD.get(op_list, None) + return DEFAULT_OP_LIST_TO_FUSER_METHOD.get(op_list, None) diff --git a/torch/quantization/fx/fuse.py b/torch/quantization/fx/fuse.py index 852de812e39d..ecabaf1eaf7e 100644 --- a/torch/quantization/fx/fuse.py +++ b/torch/quantization/fx/fuse.py @@ -7,7 +7,7 @@ from .pattern_utils import ( is_match, - get_fusion_patterns, + get_default_fusion_patterns, ) from .fusion_patterns import * # noqa: F401 @@ -21,7 +21,7 @@ def fuse(self, model, inplace=False): input_graph = model.graph self.modules = dict(input_root.named_modules()) - fusion_patterns = get_fusion_patterns() + fusion_patterns = get_default_fusion_patterns() # find fusion fusion_pairs = self._find_matches(input_root, input_graph, fusion_patterns) self.fused_graph = Graph() diff --git a/torch/quantization/fx/pattern_utils.py b/torch/quantization/fx/pattern_utils.py index 2984da0b80fd..efc1cd492d1b 100644 --- a/torch/quantization/fx/pattern_utils.py +++ b/torch/quantization/fx/pattern_utils.py @@ -3,27 +3,27 @@ from collections import OrderedDict # pattern for conv bn fusion -FUSION_PATTERNS = OrderedDict() +DEFAULT_FUSION_PATTERNS = OrderedDict() def register_fusion_pattern(pattern): def insert(fn): - FUSION_PATTERNS[pattern] = fn + DEFAULT_FUSION_PATTERNS[pattern] = fn return fn return insert -def get_fusion_patterns(): - return FUSION_PATTERNS +def get_default_fusion_patterns(): + return DEFAULT_FUSION_PATTERNS -QUANTIZATION_PATTERNS = OrderedDict() +DEFAULT_QUANTIZATION_PATTERNS = OrderedDict() # Register pattern for both static quantization and qat def register_quant_pattern(pattern): def insert(fn): - QUANTIZATION_PATTERNS[pattern] = fn + DEFAULT_QUANTIZATION_PATTERNS[pattern] = fn return fn return insert # Get patterns for both static quantization and qat -def get_quant_patterns(): - return QUANTIZATION_PATTERNS +def get_default_quant_patterns(): + return DEFAULT_QUANTIZATION_PATTERNS # Example use of register pattern function: # @register_fusion_pattern(torch.nn.ReLU, (torch.nn.BatchNorm2d, torch.nn.Conv2d))) diff --git a/torch/quantization/fx/quantize.py b/torch/quantization/fx/quantize.py index f48eb2f2aa94..92b243ef7148 100644 --- a/torch/quantization/fx/quantize.py +++ b/torch/quantization/fx/quantize.py @@ -16,14 +16,14 @@ ) from ..quantization_mappings import ( - get_qat_module_mappings, + get_default_qat_module_mappings, ) from ..quantize import _remove_qconfig from .pattern_utils import ( is_match, - get_quant_patterns, + get_default_quant_patterns, ) from .standalone_module import ( @@ -231,7 +231,7 @@ def __init__(self): def 
_qat_swap_modules(self, root): - convert(root, mapping=get_qat_module_mappings(), inplace=True, remove_qconfig=False) + convert(root, mapping=get_default_qat_module_mappings(), inplace=True, remove_qconfig=False) def _generate_qconfig_map(self, root,
@@ -324,7 +324,7 @@ def _prepare(self, model, qconfig_dict, inplace, prepare_custom_config_dict, is_ prepare_custom_config_dict = {} if not inplace: model = copy.deepcopy(model) - self.patterns = get_quant_patterns() + self.patterns = get_default_quant_patterns() flattened_qconfig_dict = get_flattened_qconfig_dict(qconfig_dict) # TODO: support regex as well
diff --git a/torch/quantization/quantization_mappings.py b/torch/quantization/quantization_mappings.py
index 87c8f7b0a02b..48c2b2ed2227 100644
--- a/torch/quantization/quantization_mappings.py
+++ b/torch/quantization/quantization_mappings.py
@@ -11,8 +11,8 @@ from .stubs import QuantStub, DeQuantStub -# Map for swapping float module to quantized ones -STATIC_QUANT_MODULE_MAPPINGS = { +# Default map for swapping float module to quantized ones +DEFAULT_STATIC_QUANT_MODULE_MAPPINGS = { QuantStub: nnq.Quantize, DeQuantStub: nnq.DeQuantize, nn.BatchNorm2d: nnq.BatchNorm2d,
@@ -53,8 +53,8 @@ nnqat.Conv2d: nnq.Conv2d, } -# Map for swapping float module to qat modules -QAT_MODULE_MAPPINGS = { +# Default map for swapping float module to qat modules +DEFAULT_QAT_MODULE_MAPPINGS = { nn.Conv2d: nnqat.Conv2d, nn.Linear: nnqat.Linear, # Intrinsic modules:
@@ -64,7 +64,7 @@ nni.LinearReLU: nniqat.LinearReLU } -# Map for swapping dynamic modules -DYNAMIC_QUANT_MODULE_MAPPINGS = { +# Default map for swapping dynamic modules +DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS = { nn.GRUCell: nnqd.GRUCell, nn.Linear: nnqd.Linear,
@@ -81,8 +81,9 @@ nn.Sequential, } -# mapping from floating point function or torch ops to quantized ops -FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS = { +# Default mapping from floating point function or torch ops to quantized ops +# TODO: merge with default static mapping +DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS = { F.elu: torch._ops.ops.quantized.elu, F.hardswish: torch._ops.ops.quantized.hardswish, F.instance_norm: torch._ops.ops.quantized.instance_norm,
@@ -90,102 +91,64 @@ F.leaky_relu: torch._ops.ops.quantized.leaky_relu, } -def register_static_quant_module_mapping( - float_source_module_class, static_quant_target_module_class): - ''' Register a mapping from `float_source__module_class` to `static_quant_target_module_class` - `static_quant_target_module_class` must have from_float defined as a class method - The mapping is used in the convert step of post training static quantization to - convert a float module to a statically quantized module. 
- ''' - assert hasattr(static_quant_target_module_class, 'from_float'), 'from_float must be defined' + \ - ' in quantized module class' - STATIC_QUANT_MODULE_MAPPINGS[float_source_module_class] = static_quant_target_module_class - -def get_static_quant_module_mappings(): +def get_default_static_quant_module_mappings(): ''' Get module mapping for post training static quantization ''' - return STATIC_QUANT_MODULE_MAPPINGS + return DEFAULT_STATIC_QUANT_MODULE_MAPPINGS def get_static_quant_module_class(float_module_class): ''' Get the statically quantized module class corresponding to the floating point module class ''' - static_quant_module_class = STATIC_QUANT_MODULE_MAPPINGS.get(float_module_class, None) + static_quant_module_class = DEFAULT_STATIC_QUANT_MODULE_MAPPINGS.get(float_module_class, None) assert static_quant_module_class is not None, \ 'Floating point module class {}'.format(float_module_class) + \ ' does not have a corresponding quantized module class' return static_quant_module_class -def register_qat_module_mapping(float_source_module_class, qat_target_module_class): - '''Register a mapping from `float_source_module_class` to `qat_target_module_class`, - `qat_target_module_class` must have from_float defined as a class method - This mapping is used in prepare step of quantization aware training to swap - a float module to a qat module. - ''' - assert hasattr(qat_target_module_class, 'from_float'), 'from_float must be defined' + \ - ' in qat module class' - QAT_MODULE_MAPPINGS[float_source_module_class] = qat_target_module_class - -def get_qat_module_mappings(): - ''' Get module mapping for quantization aware training +def get_default_qat_module_mappings(): + ''' Get default module mapping for quantization aware training ''' - return QAT_MODULE_MAPPINGS + return DEFAULT_QAT_MODULE_MAPPINGS -def register_dynamic_quant_module_class(float_source_module_class, dynamic_quant_target_module_class): - ''' Register a mapping from `float_source_module_class` to `dynamic_quant_target_module_class`, - `dynamic_quant_target_module_class` must have from_float defined as a class method - This mapping is used in convert step of post training dynamic - quantization to swap a float module to a dynamically quantized - module. 
- ''' - assert hasattr(dynamic_quant_target_module_class, 'from_float'), 'from_float must be defined' + \ - ' in dynamically quantized module type' - DYNAMIC_QUANT_MODULE_MAPPINGS[float_source_module_class] = dynamic_quant_target_module_class - -def get_dynamic_quant_module_mappings(): +def get_default_dynamic_quant_module_mappings(): ''' Get module mapping for post training dynamic quantization ''' - return DYNAMIC_QUANT_MODULE_MAPPINGS + return DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS -def get_qconfig_propagation_list(): - ''' Get the list of module types that we'll attach qconfig +def get_default_qconfig_propagation_list(): + ''' Get the default list of module types that we'll attach qconfig attribute to in prepare ''' QCONFIG_PROPAGATE_MODULE_CLASS_LIST = ( - (set(STATIC_QUANT_MODULE_MAPPINGS.keys()) | - set(QAT_MODULE_MAPPINGS.keys()) | - set(DYNAMIC_QUANT_MODULE_MAPPINGS.keys()) | + (set(DEFAULT_STATIC_QUANT_MODULE_MAPPINGS.keys()) | + set(DEFAULT_QAT_MODULE_MAPPINGS.keys()) | + set(DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS.keys()) | _INCLUDE_QCONFIG_PROPAGATE_LIST) - _EXCLUDE_QCONFIG_PROPAGATE_LIST ) return QCONFIG_PROPAGATE_MODULE_CLASS_LIST -def get_compare_output_module_list(): +def get_default_compare_output_module_list(): ''' Get list of module class types that we will record output in numeric suite ''' NUMERIC_SUITE_COMPARE_MODEL_OUTPUT_MODULE_LIST = ( - set(STATIC_QUANT_MODULE_MAPPINGS.values()) - | set(QAT_MODULE_MAPPINGS.values()) - | set(DYNAMIC_QUANT_MODULE_MAPPINGS.values()) - | set(STATIC_QUANT_MODULE_MAPPINGS.keys()) - | set(QAT_MODULE_MAPPINGS.keys()) - | set(DYNAMIC_QUANT_MODULE_MAPPINGS.keys()) + set(DEFAULT_STATIC_QUANT_MODULE_MAPPINGS.values()) + | set(DEFAULT_QAT_MODULE_MAPPINGS.values()) + | set(DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS.values()) + | set(DEFAULT_STATIC_QUANT_MODULE_MAPPINGS.keys()) + | set(DEFAULT_QAT_MODULE_MAPPINGS.keys()) + | set(DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS.keys()) | _INCLUDE_QCONFIG_PROPAGATE_LIST ) - _EXCLUDE_QCONFIG_PROPAGATE_LIST return NUMERIC_SUITE_COMPARE_MODEL_OUTPUT_MODULE_LIST -def register_quantized_operator_mapping(float_op, quantized_op): - ''' Register a mapping from `floating_point_op` (torch or functional) to `quantized_op` - This is used in convert step of fx based graph mode quantization - to convert a float op to quantized op. 
- ''' - FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS[float_op] = quantized_op - +# TODO: merge with get_static_quant_module_class def get_quantized_operator(float_op): ''' Get the quantized operator corresponding to the float operator ''' - quantized_op = FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS.get(float_op, None) + quantized_op = DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS.get(float_op, None) assert quantized_op is not None, \ 'Operator {} does not have corresponding quantized op'.format(float_op) return quantized_op diff --git a/torch/quantization/quantize.py b/torch/quantization/quantize.py index fedaa7000f99..ac0471dfa4b9 100644 --- a/torch/quantization/quantize.py +++ b/torch/quantization/quantize.py @@ -9,10 +9,10 @@ import torch.nn.quantized as nnq import torch.nn.intrinsic.qat as nniqat -from .quantization_mappings import (get_dynamic_quant_module_mappings, - get_static_quant_module_mappings, - get_qat_module_mappings, - get_qconfig_propagation_list) +from .quantization_mappings import (get_default_dynamic_quant_module_mappings, + get_default_static_quant_module_mappings, + get_default_qat_module_mappings, + get_default_qconfig_propagation_list) from .stubs import DeQuantStub, QuantWrapper from .qconfig import default_dynamic_qconfig, float16_dynamic_qconfig, float_qparams_dynamic_qconfig @@ -37,7 +37,7 @@ def _propagate_qconfig_helper(module, qconfig_dict, allow_list=None, """ # TODO: Add test if allow_list is None: - allow_list = get_qconfig_propagation_list() + allow_list = get_default_qconfig_propagation_list() module_qconfig = qconfig_dict.get(type(module), qconfig_parent) module_qconfig = qconfig_dict.get(prefix, module_qconfig) @@ -93,7 +93,7 @@ def add_observer_(module, qconfig_propagation_list=None, non_leaf_module_list=No None, module is modified inplace with added observer modules and forward_hooks """ if qconfig_propagation_list is None: - qconfig_propagation_list = get_qconfig_propagation_list() + qconfig_propagation_list = get_default_qconfig_propagation_list() if custom_module_class_mapping is None: custom_module_class_mapping = {} @@ -209,7 +209,7 @@ def prepare(model, inplace=False, allow_list=None, qconfig_propagation_list = allow_list if qconfig_propagation_list is None: - qconfig_propagation_list = get_qconfig_propagation_list() + qconfig_propagation_list = get_default_qconfig_propagation_list() propagate_qconfig_(model, qconfig_dict=None) # sanity check common API misusage @@ -255,7 +255,7 @@ def quantize(model, run_fn, run_args, mapping=None, inplace=False): """ torch._C._log_api_usage_once("quantization_api.quantize.quantize") if mapping is None: - mapping = get_static_quant_module_mappings() + mapping = get_default_static_quant_module_mappings() if not inplace: model = copy.deepcopy(model) model.eval() @@ -333,7 +333,7 @@ def quantize_dynamic(model, qconfig_spec=None, dtype=torch.qint8, qconfig_spec = dict(zip(qconfig_spec, itertools.repeat(default_qconfig))) if mapping is None: - mapping = get_dynamic_quant_module_mappings() + mapping = get_default_dynamic_quant_module_mappings() if not inplace: model = copy.deepcopy(model) @@ -359,7 +359,7 @@ def prepare_qat(model, mapping=None, inplace=False): """ torch._C._log_api_usage_once("quantization_api.quantize.prepare_qat") if mapping is None: - mapping = get_qat_module_mappings() + mapping = get_default_qat_module_mappings() if not inplace: model = copy.deepcopy(model) @@ -440,7 +440,7 @@ def _convert( """ if mapping is None: - mapping = get_static_quant_module_mappings() + mapping = 
get_default_static_quant_module_mappings() if convert_custom_config_dict is None: convert_custom_config_dict = {} custom_module_class_mapping = convert_custom_config_dict.get("observed_to_quantized_custom_module_class", {}) diff --git a/torch/testing/_internal/common_quantization.py b/torch/testing/_internal/common_quantization.py index 3c94e8d8cc3b..90fd86647948 100644 --- a/torch/testing/_internal/common_quantization.py +++ b/torch/testing/_internal/common_quantization.py @@ -14,9 +14,9 @@ propagate_qconfig_, convert, get_default_qconfig, quantize_dynamic_jit, quantize_jit, float_qparams_dynamic_qconfig, \ get_default_qat_qconfig, PerChannelMinMaxObserver, default_dynamic_quant_observer, QConfigDynamic from torch.quantization.quantization_mappings import ( - get_dynamic_quant_module_mappings, - get_qconfig_propagation_list, - get_qat_module_mappings, + get_default_dynamic_quant_module_mappings, + get_default_qconfig_propagation_list, + get_default_qat_module_mappings, ) # symbolic trace from torch.fx import symbolic_trace @@ -186,7 +186,7 @@ def run_ddp(rank, world_size, prepared): def convert_dynamic(module): - convert(module, get_dynamic_quant_module_mappings(), inplace=True) + convert(module, get_default_dynamic_quant_module_mappings(), inplace=True) def prepare_dynamic(model, qconfig_dict=None): propagate_qconfig_(model, qconfig_dict) @@ -342,7 +342,7 @@ def checkObservers(self, module, propagate_qconfig_list=None, prepare_custom_con have observers in preperation for quantization """ if propagate_qconfig_list is None: - propagate_qconfig_list = get_qconfig_propagation_list() + propagate_qconfig_list = get_default_qconfig_propagation_list() if prepare_custom_config_dict is None: prepare_custom_config_dict = {} float_to_observed_module_class_mapping = prepare_custom_config_dict.get("float_to_observed_custom_module_class", {}) @@ -363,7 +363,7 @@ def is_leaf_module(module): 'module: ' + str(type(module)) + ' do not have observer') # we don't need to check observers for child modules of the # qat modules - if type(module) not in get_qat_module_mappings().values() and \ + if type(module) not in get_default_qat_module_mappings().values() and \ type(module) not in float_to_observed_module_class_mapping.values(): for child in module.children(): self.checkObservers(child, propagate_qconfig_list, prepare_custom_config_dict)
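
Illustrative usage after this change (a minimal sketch; `MyFloatModule`, `MyQuantizedModule`, `float_model`, `calibrate`, and `calib_data` are hypothetical placeholders): callers copy the default mapping and pass it explicitly, instead of registering custom entries into a global table.

    from torch.quantization import default_qconfig, quantize
    from torch.quantization.quantization_mappings import (
        get_default_static_quant_module_mappings,
    )

    # Start from the default float -> quantized module mapping and extend it
    # locally; nothing is registered globally, so other transformations that
    # rely on the defaults are unaffected.
    mapping = dict(get_default_static_quant_module_mappings())
    mapping[MyFloatModule] = MyQuantizedModule  # hypothetical custom pair

    float_model.eval()
    float_model.qconfig = default_qconfig
    quantized_model = quantize(float_model, calibrate, calib_data, mapping=mapping)

The same pattern applies to the other entry points touched by this patch, e.g. convert(model, mapping=...) and prepare_qat(model, mapping=...).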