Commit fa30163

[quant][graphmode][fx] Add additional_object_mapping argument to convert
Summary:
Adds an `additional_object_mapping` entry to `convert_custom_config_dict` so users can supply extra float-to-quantized mappings that overwrite the defaults during convert. Open question: should we merge the quantized module and quantized operator configurations?

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
jerryzh168 committed Oct 14, 2020
1 parent ce57dad commit fa30163
Showing 3 changed files with 35 additions and 9 deletions.
17 changes: 14 additions & 3 deletions torch/quantization/fx/quantization_patterns.py
@@ -177,6 +177,9 @@ def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_dict=None):
# note that relu should already be fused into conv module in the fusion step
assert self.relu_node is None, 'conv module and relu fusion is not executed, ' \
'please make sure to run fusion before prepare'
if convert_custom_config_dict is None:
convert_custom_config_dict = {}
additional_static_quant_mapping = convert_custom_config_dict.get("static", {})
# 1. attach activation post process to module
if type(self.conv) in [
torch.nn.intrinsic.ConvReLU1d,
@@ -187,7 +190,8 @@
else:
self.conv.activation_post_process = quantizer.activation_post_process_map[node.name]
# 2. select quantized class
qconv_cls = get_static_quant_module_class(type(self.conv))
qconv_cls = get_static_quant_module_class(
type(self.conv), additional_static_quant_mapping)
quantized = qconv_cls.from_float(self.conv)
parent_name, name = _parent_name(self.conv_node.target)
setattr(quantizer.modules[parent_name], name, quantized)
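
The swap above works because convert only assumes a from_float constructor on the quantized class. A minimal sketch of that contract for a user-supplied class (all names except from_float and activation_post_process are hypothetical):

    import torch

    class MyQuantizedConv2d(torch.nn.Module):
        # Hypothetical replacement class registered via the "static" mapping.
        # convert attaches the observer in step 1, then calls from_float in
        # step 2 with the observed float module.
        def __init__(self, scale, zero_point):
            super().__init__()
            self.scale = scale
            self.zero_point = zero_point

        @classmethod
        def from_float(cls, float_mod):
            # calculate_qparams() returns the scale/zero_point the attached
            # observer derived from calibration statistics
            scale, zero_point = float_mod.activation_post_process.calculate_qparams()
            return cls(float(scale), int(zero_point))
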
@@ -363,6 +367,9 @@ def __init__(self, quantizer, node):
self.bn = quantizer.modules[self.bn_node.target]

def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_dict=None):
if convert_custom_config_dict is None:
convert_custom_config_dict = {}
additional_static_quant_mapping = convert_custom_config_dict.get("static", {})
# 1. attach activation post process to module
activation_post_process = quantizer.activation_post_process_map[node.name]
if type(self.bn) in \
@@ -371,7 +378,7 @@ def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_dict=None):
self.bn[1].activation_post_process = activation_post_process
else:
self.bn.activation_post_process = activation_post_process
qbn_cls = get_static_quant_module_class(type(self.bn))
qbn_cls = get_static_quant_module_class(type(self.bn), additional_static_quant_mapping)
quantized = qbn_cls.from_float(self.bn)
parent_name, name = _parent_name(self.bn_node.target)
setattr(quantizer.modules[parent_name], name, quantized)
@@ -405,11 +412,15 @@ def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_dict=None):
return NotImplemented
assert node.op in ['call_module', 'call_function'], 'Only call_module and ' + \
'call_function are handled in DefaultNode'
if convert_custom_config_dict is None:
convert_custom_config_dict = {}
additional_static_quant_mapping = convert_custom_config_dict.get("static", {})
activation_post_process = quantizer.activation_post_process_map[node.name]
if node.op == 'call_module':
module = quantizer.modules[node.target]
module.activation_post_process = activation_post_process
quantized_module_cls = get_static_quant_module_class(type(module))
quantized_module_cls = get_static_quant_module_class(
type(module), additional_static_quant_mapping)
quantized_module = quantized_module_cls.from_float(module)
parent_name, name = _parent_name(node.target)
setattr(quantizer.modules[parent_name], name, quantized_module)
15 changes: 9 additions & 6 deletions torch/quantization/quantization_mappings.py
@@ -96,14 +96,17 @@ def get_default_static_quant_module_mappings():
'''
return DEFAULT_STATIC_QUANT_MODULE_MAPPINGS

def get_static_quant_module_class(float_module_class):
''' Get the statically quantized module class corresponding to
def get_static_quant_module_class(float_module_class, additional_static_quant_mapping=None):
r"""n Get the statically quantized module class corresponding to
the floating point module class
'''
static_quant_module_class = DEFAULT_STATIC_QUANT_MODULE_MAPPINGS.get(float_module_class, None)
"""
if additional_static_quant_mapping is None:
additional_static_quant_mapping = {}
# merge via dict-display unpacking; dict(d, **d2) would reject the non-string (class) keys used here
all_mappings = {**DEFAULT_STATIC_QUANT_MODULE_MAPPINGS, **additional_static_quant_mapping}
static_quant_module_class = all_mappings.get(float_module_class, None)
assert static_quant_module_class is not None, \
'Floating point module class {}'.format(float_module_class) + \
' does not have a corresponding quantized module class'
"Floating point module class {}".format(float_module_class) + \
" does not have a corresponding quantized module class"
return static_quant_module_class

def get_default_qat_module_mappings():
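
Because the additional mapping is merged in last, user-supplied entries passed to get_static_quant_module_class shadow the defaults. A short usage sketch (both module classes are hypothetical):

    import torch.nn as nn
    from torch.quantization.quantization_mappings import get_static_quant_module_class

    class MyFloatModule(nn.Module):
        pass

    class MyQuantizedModule(nn.Module):
        @classmethod
        def from_float(cls, float_mod):
            return cls()

    # without the second argument this lookup would trip the assertion above
    qcls = get_static_quant_module_class(
        MyFloatModule, {MyFloatModule: MyQuantizedModule})
    assert qcls is MyQuantizedModule
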
12 changes: 12 additions & 0 deletions torch/quantization/quantize_fx.py
@@ -254,6 +254,18 @@ def convert_fx(graph_module, inplace=False, debug=False, convert_custom_config_dict=None):
`debug`: flag for producing a debug friendly model (preserve weight attribute)
`convert_custom_config_dict`: dictionary for custom configurations for convert function:
convert_custom_config_dict = {
# additional object (module/operator) mappings that will overwrite the default
# module mapping
"additional_object_mapping": {
"static": {
FloatModule: QuantizedModule,
float_op: quantized_op
},
"dynamic": {
FloatModule: DynamicallyQuantizedModule,
float_op: dynamically_quantized_op
},
}
# user will manually define the corresponding quantized
# module class which has a from_observed class method that converts
# observed custom module to quantized custom module
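
To make the documented shape concrete, a hedged end-to-end sketch (MyModel, MyFloatModule, and MyQuantizedModule are hypothetical; note that the pattern handlers in this commit read the top-level "static" key rather than the nested "additional_object_mapping" one, reflecting the open question in the summary):

    import torch
    from torch.quantization import get_default_qconfig
    from torch.quantization.quantize_fx import prepare_fx, convert_fx

    float_model = MyModel().eval()
    prepared = prepare_fx(float_model, {"": get_default_qconfig("fbgemm")})

    # ... run calibration batches through `prepared` ...

    quantized = convert_fx(
        prepared,
        convert_custom_config_dict={
            "additional_object_mapping": {
                "static": {MyFloatModule: MyQuantizedModule},
            },
        },
    )
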
