[quant][graphmode][fx] Add additional_object_mapping argument to convert #46338

Closed
wants to merge 7 commits into from
17 changes: 14 additions & 3 deletions torch/quantization/fx/quantization_patterns.py
@@ -177,6 +177,9 @@ def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_
# note that relu should already be fused into conv module in the fusion step
assert self.relu_node is None, 'conv module and relu fusion is not executed, ' \
'please make sure to run fusion before prepare'
if convert_custom_config_dict is None:
convert_custom_config_dict = {}
additional_static_quant_mapping = convert_custom_config_dict.get("static", {})
# 1. attach activation post process to module
if type(self.conv) in [
torch.nn.intrinsic.ConvReLU1d,
@@ -187,7 +190,8 @@
else:
self.conv.activation_post_process = quantizer.activation_post_process_map[node.name]
# 2. select quantized class
-        qconv_cls = get_static_quant_module_class(type(self.conv))
+        qconv_cls = get_static_quant_module_class(
+            type(self.conv), additional_static_quant_mapping)
quantized = qconv_cls.from_float(self.conv)
parent_name, name = _parent_name(self.conv_node.target)
setattr(quantizer.modules[parent_name], name, quantized)
@@ -363,6 +367,9 @@ def __init__(self, quantizer, node):
self.bn = quantizer.modules[self.bn_node.target]

def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_dict=None):
if convert_custom_config_dict is None:
convert_custom_config_dict = {}
additional_static_quant_mapping = convert_custom_config_dict.get("static", {})
# 1. attach activation post process to module
activation_post_process = quantizer.activation_post_process_map[node.name]
if type(self.bn) in \
@@ -371,7 +378,7 @@
self.bn[1].activation_post_process = activation_post_process
else:
self.bn.activation_post_process = activation_post_process
-        qbn_cls = get_static_quant_module_class(type(self.bn))
+        qbn_cls = get_static_quant_module_class(type(self.bn), additional_static_quant_mapping)
quantized = qbn_cls.from_float(self.bn)
parent_name, name = _parent_name(self.bn_node.target)
setattr(quantizer.modules[parent_name], name, quantized)
@@ -405,11 +412,15 @@
return NotImplemented
assert node.op in ['call_module', 'call_function'], 'Only call_module and ' + \
'call_function are handled in DefaultNode'
if convert_custom_config_dict is None:
convert_custom_config_dict = {}
additional_static_quant_mapping = convert_custom_config_dict.get("static", {})
activation_post_process = quantizer.activation_post_process_map[node.name]
if node.op == 'call_module':
module = quantizer.modules[node.target]
module.activation_post_process = activation_post_process
-            quantized_module_cls = get_static_quant_module_class(type(module))
+            quantized_module_cls = get_static_quant_module_class(
+                type(module), additional_static_quant_mapping)
quantized_module = quantized_module_cls.from_float(module)
parent_name, name = _parent_name(node.target)
setattr(quantizer.modules[parent_name], name, quantized_module)
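
Taken together, the convert hunks above all follow the same swap pattern: pull the optional "static" mapping out of convert_custom_config_dict, attach the observed activation post process to the float module, look up the quantized class (defaults merged with any user-supplied additions), and replace the module on its parent. The following is a condensed, illustrative sketch of that pattern, not the actual handler code; quantizer, node, and _parent_name stand in for the objects the real pattern handlers receive.

# Illustrative sketch only: the shared module-swap step used by the handlers above.
# `quantizer`, `node`, and `_parent_name` stand in for what the real handlers receive.
def swap_module_to_quantized(quantizer, node, convert_custom_config_dict=None):
    if convert_custom_config_dict is None:
        convert_custom_config_dict = {}
    # user-supplied float -> quantized overrides for static quantization
    additional_static_quant_mapping = convert_custom_config_dict.get("static", {})

    float_module = quantizer.modules[node.target]
    # 1. attach the observer recorded for this node during prepare
    float_module.activation_post_process = quantizer.activation_post_process_map[node.name]
    # 2. pick the quantized class, with user additions overriding the defaults
    quantized_cls = get_static_quant_module_class(
        type(float_module), additional_static_quant_mapping)
    # 3. construct the quantized module and swap it onto the parent module
    quantized = quantized_cls.from_float(float_module)
    parent_name, name = _parent_name(node.target)
    setattr(quantizer.modules[parent_name], name, quantized)
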
15 changes: 9 additions & 6 deletions torch/quantization/quantization_mappings.py
@@ -96,14 +96,17 @@ def get_default_static_quant_module_mappings():
'''
return DEFAULT_STATIC_QUANT_MODULE_MAPPINGS

-def get_static_quant_module_class(float_module_class):
-    ''' Get the statically quantized module class corresponding to
+def get_static_quant_module_class(float_module_class, additional_static_quant_mapping=None):
+    r""" Get the statically quantized module class corresponding to
    the floating point module class
-    '''
-    static_quant_module_class = DEFAULT_STATIC_QUANT_MODULE_MAPPINGS.get(float_module_class, None)
+    """
+    if additional_static_quant_mapping is None:
+        additional_static_quant_mapping = {}
+    all_mappings = dict(DEFAULT_STATIC_QUANT_MODULE_MAPPINGS, **additional_static_quant_mapping)
+    static_quant_module_class = all_mappings.get(float_module_class, None)
    assert static_quant_module_class is not None, \
-        'Floating point module class {}'.format(float_module_class) + \
-        ' does not have a corresponding quantized module class'
+        "Floating point module class {}".format(float_module_class) + \
+        " does not have a corresponding quantized module class"
    return static_quant_module_class

def get_default_qat_module_mappings():
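
For illustration, here is a hedged usage sketch of the updated lookup. MyFloatModule and MyQuantizedModule are hypothetical user-defined classes, not part of torch; a user-provided quantized class is expected to offer a from_float constructor.

# Hypothetical example classes for this sketch (not part of torch).
import torch.nn as nn
from torch.quantization.quantization_mappings import get_static_quant_module_class

class MyFloatModule(nn.Module):
    def forward(self, x):
        return x

class MyQuantizedModule(nn.Module):
    @classmethod
    def from_float(cls, mod):
        # build the quantized module from the observed float module
        return cls()
    def forward(self, x):
        return x

# Default mappings behave as before:
qlinear_cls = get_static_quant_module_class(nn.Linear)  # nn.quantized.Linear

# User additions are merged on top of (and override) the defaults:
qcustom_cls = get_static_quant_module_class(
    MyFloatModule, {MyFloatModule: MyQuantizedModule})   # MyQuantizedModule
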
12 changes: 12 additions & 0 deletions torch/quantization/quantize_fx.py
@@ -253,6 +253,18 @@ def convert_fx(graph_module, inplace=False, debug=False, convert_custom_config_d
`debug`: flag for producing a debug friendly model (preserve weight attribute)
`convert_custom_config_dict`: dictionary for custom configurations for convert function:
convert_custom_config_dict = {
# additional object (module/operator) mappings that will overwrite the default
# module mapping
"additional_object_mapping": {
"static": {
FloatModule: QuantizedModule,
float_op: quantized_op
},
"dynamic": {
FloatModule: DynamicallyQuantizedModule,
float_op: dynamically_quantized_op
},
Comment on lines +259 to +266

Contributor

How would custom mappings for QAT, and/or for models that have some parts quantized dynamically and some statically, fit in?

Contributor Author

You mean additional QAT module mappings? Those are added in the next PR. They are largely independent of these, since that swap happens before we insert observers, e.g. nn.Conv2d is swapped to nn.qat.Conv2d.

"static" and "dynamic" here are used during convert, e.g. nn.qat.Conv2d to nn.quantized.Conv2d.

}
# user will manually define the corresponding quantized
# module class which has a from_observed class method that converts
# observed custom module to quantized custom module
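
To make the new key concrete, below is a minimal, hedged sketch of threading additional_object_mapping through the FX graph mode flow, reusing the hypothetical MyFloatModule/MyQuantizedModule pair from the earlier sketch. Whether observers are actually inserted for a custom module still depends on the rest of the prepare configuration, so this only illustrates how the argument is passed.

# Minimal, hypothetical sketch of passing the new key to convert_fx.
# MyFloatModule / MyQuantizedModule are the example classes sketched above.
import torch
import torch.nn as nn
from torch.quantization import default_qconfig
from torch.quantization.quantize_fx import prepare_fx, convert_fx

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.custom = MyFloatModule()
    def forward(self, x):
        return self.custom(x)

prepared = prepare_fx(Model().eval(), {"": default_qconfig})
prepared(torch.randn(2, 4))   # calibration pass

convert_custom_config_dict = {
    "additional_object_mapping": {
        # user-supplied float -> quantized classes used during static module swapping
        "static": {MyFloatModule: MyQuantizedModule},
    }
}
quantized = convert_fx(prepared, convert_custom_config_dict=convert_custom_config_dict)
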