From fa30163c1223e4a1c238f356eefba5f34c586542 Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Wed, 14 Oct 2020 13:50:10 -0700
Subject: [PATCH] [quant][graphmode][fx] Add additional_object_mapping
 argument to convert

Summary:
Adds an "additional_object_mapping" entry to convert_custom_config_dict so
that user-provided float-to-quantized mappings are merged with (and can
override) the default module mappings during convert.

Open question: should we merge the quantized module and quantized operator
configurations?

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
---
 torch/quantization/fx/quantization_patterns.py | 17 ++++++++++++++---
 torch/quantization/quantization_mappings.py    | 15 +++++++++------
 torch/quantization/quantize_fx.py              | 12 ++++++++++++
 3 files changed, 35 insertions(+), 9 deletions(-)

diff --git a/torch/quantization/fx/quantization_patterns.py b/torch/quantization/fx/quantization_patterns.py
index 656b8f9bec88..f72853510ca0 100644
--- a/torch/quantization/fx/quantization_patterns.py
+++ b/torch/quantization/fx/quantization_patterns.py
@@ -177,6 +177,9 @@ def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_
         # note that relu should already be fused into conv module in the fusion step
         assert self.relu_node is None, 'conv module and relu fusion is not executed, ' \
             'please make sure to run fusion before prepare'
+        if convert_custom_config_dict is None:
+            convert_custom_config_dict = {}
+        additional_static_quant_mapping = convert_custom_config_dict.get("static", {})
         # 1. attach activation post process to module
         if type(self.conv) in [
                 torch.nn.intrinsic.ConvReLU1d,
@@ -187,7 +190,8 @@ def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_
         else:
             self.conv.activation_post_process = quantizer.activation_post_process_map[node.name]
         # 2. select quantized class
-        qconv_cls = get_static_quant_module_class(type(self.conv))
+        qconv_cls = get_static_quant_module_class(
+            type(self.conv), additional_static_quant_mapping)
         quantized = qconv_cls.from_float(self.conv)
         parent_name, name = _parent_name(self.conv_node.target)
         setattr(quantizer.modules[parent_name], name, quantized)
@@ -363,6 +367,9 @@ def __init__(self, quantizer, node):
         self.bn = quantizer.modules[self.bn_node.target]
 
     def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_dict=None):
+        if convert_custom_config_dict is None:
+            convert_custom_config_dict = {}
+        additional_static_quant_mapping = convert_custom_config_dict.get("static", {})
         # 1. attach activation post process to module
         activation_post_process = quantizer.activation_post_process_map[node.name]
         if type(self.bn) in \
@@ -371,7 +378,7 @@ def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_
             self.bn[1].activation_post_process = activation_post_process
         else:
             self.bn.activation_post_process = activation_post_process
-        qbn_cls = get_static_quant_module_class(type(self.bn))
+        qbn_cls = get_static_quant_module_class(type(self.bn), additional_static_quant_mapping)
         quantized = qbn_cls.from_float(self.bn)
         parent_name, name = _parent_name(self.bn_node.target)
         setattr(quantizer.modules[parent_name], name, quantized)
@@ -405,11 +412,15 @@ def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_
             return NotImplemented
         assert node.op in ['call_module', 'call_function'], 'Only call_module and ' + \
             'call_function are handled in DefaultNode'
+        if convert_custom_config_dict is None:
+            convert_custom_config_dict = {}
+        additional_static_quant_mapping = convert_custom_config_dict.get("static", {})
         activation_post_process = quantizer.activation_post_process_map[node.name]
         if node.op == 'call_module':
             module = quantizer.modules[node.target]
             module.activation_post_process = activation_post_process
-            quantized_module_cls = get_static_quant_module_class(type(module))
+            quantized_module_cls = get_static_quant_module_class(
+                type(module), additional_static_quant_mapping)
             quantized_module = quantized_module_cls.from_float(module)
             parent_name, name = _parent_name(node.target)
             setattr(quantizer.modules[parent_name], name, quantized_module)
diff --git a/torch/quantization/quantization_mappings.py b/torch/quantization/quantization_mappings.py
index 48c2b2ed2227..19cc8250d41c 100644
--- a/torch/quantization/quantization_mappings.py
+++ b/torch/quantization/quantization_mappings.py
@@ -96,14 +96,17 @@ def get_default_static_quant_module_mappings():
     '''
     return DEFAULT_STATIC_QUANT_MODULE_MAPPINGS
 
-def get_static_quant_module_class(float_module_class):
-    ''' Get the statically quantized module class corresponding to
+def get_static_quant_module_class(float_module_class, additional_static_quant_mapping=None):
+    r"""Get the statically quantized module class corresponding to
     the floating point module class
-    '''
-    static_quant_module_class = DEFAULT_STATIC_QUANT_MODULE_MAPPINGS.get(float_module_class, None)
+    """
+    if additional_static_quant_mapping is None:
+        additional_static_quant_mapping = {}
+    all_mappings = dict(DEFAULT_STATIC_QUANT_MODULE_MAPPINGS, **additional_static_quant_mapping)
+    static_quant_module_class = all_mappings.get(float_module_class, None)
     assert static_quant_module_class is not None, \
-        'Floating point module class {}'.format(float_module_class) + \
-        ' does not have a corresponding quantized module class'
+        "Floating point module class {}".format(float_module_class) + \
+        " does not have a corresponding quantized module class"
     return static_quant_module_class
 
 def get_default_qat_module_mappings():
diff --git a/torch/quantization/quantize_fx.py b/torch/quantization/quantize_fx.py
index 11349f081d46..f14ff18c969d 100644
--- a/torch/quantization/quantize_fx.py
+++ b/torch/quantization/quantize_fx.py
@@ -254,6 +254,18 @@ def convert_fx(graph_module, inplace=False, debug=False, convert_custom_config_d
         `debug`: flag for producing a debug friendly model (preserve weight attribute)
         `convert_custom_config_dict`: dictionary for custom configurations for convert function:
         convert_custom_config_dict = {
+          # additional object (module/operator) mappings that will override the
+          # default module mappings
+          "additional_object_mapping": {
+            "static": {
+              FloatModule: QuantizedModule,
+              float_op: quantized_op
+            },
+            "dynamic": {
+              FloatModule: DynamicallyQuantizedModule,
+              float_op: dynamically_quantized_op
+            },
+          },
           # user will manually define the corresponding quantized
           # module class which has a from_observed class method that converts
           # observed custom module to quantized custom module
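
A minimal sketch of how the merged mapping behaves with this patch applied,
for reviewers who want to try it. MyQuantizedConv2d is a placeholder class
invented here purely for illustration. Note that in this revision the convert
handlers read convert_custom_config_dict.get("static", {}) directly, rather
than looking under the "additional_object_mapping" key documented in the
convert_fx docstring above:

    import torch
    import torch.nn.quantized  # make the torch.nn.quantized namespace available
    from torch.quantization.quantization_mappings import (
        get_static_quant_module_class,
    )

    # Without an additional mapping, the default mapping applies.
    assert get_static_quant_module_class(torch.nn.Conv2d) is \
        torch.nn.quantized.Conv2d

    # Placeholder quantized module class, for illustration only.
    class MyQuantizedConv2d(torch.nn.quantized.Conv2d):
        pass

    # The additional mapping is merged on top of the defaults and can
    # override an existing entry.
    qcls = get_static_quant_module_class(
        torch.nn.Conv2d, {torch.nn.Conv2d: MyQuantizedConv2d})
    assert qcls is MyQuantizedConv2d

    # At the convert_fx level, the same mapping would be passed as:
    #   convert_fx(prepared_model,
    #              convert_custom_config_dict={
    #                  "static": {torch.nn.Conv2d: MyQuantizedConv2d}})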