[reland][quant][eagermode] Move custom_module registration to prepare/convert_custom_config_dict (#46293)

Summary:
Removes the global custom module registration API (`register_observed_custom_module_mapping`, `register_quantized_custom_module_mapping`, and the backing `torch/quantization/custom_module_class_mappings.py` module) and instead threads the float-to-observed and observed-to-quantized class mappings through the new `prepare_custom_config_dict` argument of `prepare` and the new `convert_custom_config_dict` argument of `convert`.

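The API change in a nutshell, as exercised by the updated test below. This is a minimal sketch: `float_model` and `data` are placeholder names, and `CustomModule`, `ObservedCustomModule`, and `QuantizedCustomModule` are the custom module classes defined in the test.

```python
from torch.quantization import default_qconfig, prepare, convert

# Before this commit: one-time global registration (removed here)
# register_observed_custom_module_mapping(CustomModule, ObservedCustomModule)
# register_quantized_custom_module_mapping(CustomModule, QuantizedCustomModule)

# After this commit: pass the mappings explicitly on each call
float_model.qconfig = default_qconfig
m = prepare(
    float_model,
    prepare_custom_config_dict={
        "float_to_observed_custom_module_class": {
            CustomModule: ObservedCustomModule
        }
    })
m(data)  # calibration
m = convert(
    m,
    convert_custom_config_dict={
        "observed_to_quantized_custom_module_class": {
            ObservedCustomModule: QuantizedCustomModule
        }
    })
```
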
Test Plan: Imported from OSS

Reviewed By: raghuramank100

ghstack-source-id: 60ab2608d0340b54fae28f3cf855f004e312fd20
Pull Request resolved: #46364
jerryzh168 committed Oct 15, 2020
1 parent ff0af72 commit fd05b9c
Showing 5 changed files with 100 additions and 135 deletions.
25 changes: 17 additions & 8 deletions test/quantization/test_quantize.py
@@ -23,8 +23,6 @@
per_channel_dynamic_qconfig,
float16_dynamic_qconfig,
float_qparams_dynamic_qconfig,
register_observed_custom_module_mapping,
register_quantized_custom_module_mapping,
PerChannelMinMaxObserver,
QConfigDynamic,
default_dynamic_quant_observer
@@ -627,9 +625,6 @@ def from_observed(cls, observed_module):
quantized = cls(nnq.Conv2d.from_float(observed_module.conv))
return quantized

register_observed_custom_module_mapping(CustomModule, ObservedCustomModule)
register_quantized_custom_module_mapping(CustomModule, QuantizedCustomModule)

class M(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -670,14 +665,28 @@ def forward(self, x):
original_ref_m.conv2.bias = torch.nn.Parameter(original_m.custom.conv.bias.detach())

original_m.qconfig = default_qconfig
m = prepare(original_m)
self.checkObservers(m)
prepare_custom_config_dict = {
"float_to_observed_custom_module_class": {
CustomModule: ObservedCustomModule
}
}
convert_custom_config_dict = {
"observed_to_quantized_custom_module_class": {
ObservedCustomModule: QuantizedCustomModule
}
}
m = prepare(
original_m,
prepare_custom_config_dict=prepare_custom_config_dict)
self.checkObservers(m, None, prepare_custom_config_dict)
# calibration
m(data)
# all activation observers are inserted in the top level module

# check converted/quantized model
m = convert(m)
m = convert(
m,
convert_custom_config_dict=convert_custom_config_dict)
# check if the module is properly quantized
self.assertEqual(type(m.quant), nnq.Quantize)
self.assertEqual(type(m.conv), nnq.Conv2d)
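For reference, the hunk above elides the definitions of the three custom module classes the test uses. A self-contained sketch of their expected shape follows; the conv-wrapping bodies and the observer hand-off in `from_observed` are assumptions, while the `from_observed` entry point matches the context lines shown above.

```python
import torch.nn as nn
import torch.nn.quantized as nnq

class CustomModule(nn.Module):
    # Float custom module; the single-conv body is an assumption.
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 1)

    def forward(self, x):
        return self.conv(x)

class ObservedCustomModule(nn.Module):
    def __init__(self, conv):
        super().__init__()
        self.conv = conv

    def forward(self, x):
        return self.conv(x)

    @classmethod
    def from_float(cls, float_module):
        # Called by `prepare` through the
        # "float_to_observed_custom_module_class" mapping.
        assert hasattr(float_module, 'qconfig')
        observed = cls(float_module.conv)
        observed.qconfig = float_module.qconfig
        return observed

class QuantizedCustomModule(nn.Module):
    def __init__(self, conv):
        super().__init__()
        self.conv = conv

    def forward(self, x):
        return self.conv(x)

    @classmethod
    def from_observed(cls, observed_module):
        # Called by `convert` through the
        # "observed_to_quantized_custom_module_class" mapping.
        # Hand the observer collected on the observed module to the
        # inner conv so nnq.Conv2d.from_float can compute qparams.
        observed_module.conv.activation_post_process = \
            observed_module.activation_post_process
        quantized = cls(nnq.Conv2d.from_float(observed_module.conv))
        return quantized
```
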
7 changes: 0 additions & 7 deletions torch/quantization/__init__.py
@@ -9,7 +9,6 @@
from .quantize_fx import *
from .quantization_mappings import *
from .fuser_method_mappings import *
from .custom_module_class_mappings import *

def default_eval_fn(model, calib_data):
r"""
@@ -41,12 +40,6 @@ def default_eval_fn(model, calib_data):
'get_compare_output_module_list',
'register_quantized_operator_mapping', 'get_quantized_operator',
'register_fuser_method', 'get_fuser_method',
'register_observed_custom_module_mapping',
'get_observed_custom_module_class',
'register_quantized_custom_mdoule_mapping',
'get_quantized_custom_module_class',
'is_custom_module_class',
'is_observed_custom_module',
# Sub functions for `prepare` and `swap_module`
'propagate_qconfig_', 'add_quant_dequant', 'add_observer_', 'swap_module',
'default_eval_fn', 'get_observer_dict',
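With the exports above removed (and the backing module deleted below), the registration helpers can no longer be imported from torch.quantization. A quick check, assuming this commit is applied:

```python
# The global registration API is gone; use prepare_custom_config_dict /
# convert_custom_config_dict on prepare/convert instead.
try:
    from torch.quantization import register_observed_custom_module_mapping
except ImportError:
    print("expected after this commit; pass config dicts instead")
```
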
75 changes: 0 additions & 75 deletions torch/quantization/custom_module_class_mappings.py

This file was deleted.

107 changes: 73 additions & 34 deletions torch/quantization/quantize.py
@@ -14,14 +14,6 @@
get_qat_module_mappings,
get_qconfig_propagation_list)

from .custom_module_class_mappings import (
is_custom_module_class,
get_observed_custom_module_class,
get_quantized_custom_module_class,
mark_observed_custom_module,
is_observed_custom_module,
)

from .stubs import DeQuantStub, QuantWrapper
from .qconfig import default_dynamic_qconfig, float16_dynamic_qconfig, float_qparams_dynamic_qconfig

@@ -86,7 +78,7 @@ def register_activation_post_process_hook(module):
'Expect activation_post_process attribute already attached to the module'
return module.register_forward_hook(_observer_forward_hook)

def add_observer_(module, qconfig_propagation_list=None, non_leaf_module_list=None, device=None):
def add_observer_(module, qconfig_propagation_list=None, non_leaf_module_list=None, device=None, custom_module_class_mapping=None):
r"""Add observer for the leaf child of the module.
This function inserts an observer module for every leaf child module that
@@ -103,6 +95,9 @@
if qconfig_propagation_list is None:
qconfig_propagation_list = get_qconfig_propagation_list()

if custom_module_class_mapping is None:
custom_module_class_mapping = {}

# respect device affinity when adding observers
if device is None:
devices = get_unique_devices_(module)
@@ -139,13 +134,12 @@ def insert_activation_post_process(m):
child.activation_post_process = get_activation_post_process(child.qconfig, device)
elif non_leaf_module_list is not None and type(child) in non_leaf_module_list:
insert_activation_post_process(child)
elif needs_observation(child) and is_custom_module_class(type(child)):
observed_child = get_observed_custom_module_class(type(child)).from_float(child)
mark_observed_custom_module(observed_child, type(child))
elif needs_observation(child) and type(child) in custom_module_class_mapping:
observed_child = custom_module_class_mapping[type(child)].from_float(child)
setattr(module, name, observed_child)
insert_activation_post_process(observed_child)
else:
add_observer_(child, qconfig_propagation_list, non_leaf_module_list, device)
add_observer_(child, qconfig_propagation_list, non_leaf_module_list, device, custom_module_class_mapping)

# Insert observers only for leaf nodes, note that this observer is for
# the output of the module, for input QuantStub will observe them
@@ -180,7 +174,8 @@ def add_quant_dequant(module):
return module

def prepare(model, inplace=False, allow_list=None,
observer_non_leaf_module_list=None):
observer_non_leaf_module_list=None,
prepare_custom_config_dict=None):
r"""Prepares a copy of the model for quantization calibration or quantization-aware training.
Quantization configuration should be assigned preemptively
@@ -190,12 +185,30 @@ def prepare(model, inplace=False, allow_list=None,
will be propagated.
Args:
model: input model to be modified in-place
inplace: carry out model transformations in-place, the original module is mutated
allow_list: list of quantizable modules
observer_non_leaf_module_list: list of non-leaf modules we want to add observer
`model`: input model to be modified in-place
`inplace`: carry out model transformations in-place, the original module is mutated
`allow_list`: list of quantizable modules
`observer_non_leaf_module_list`: list of non-leaf modules we want to add observer
`prepare_custom_config_dict`: customization configuration dictionary for prepare function
Example:
```python
prepare_custom_config_dict = {
# user will manually define the corresponding observed
# module class which has a from_float class method that converts
# float custom module to observed custom module
"float_to_observed_custom_module_class": {
CustomModule: ObservedCustomModule
}
}
```
"""
torch._C._log_api_usage_once("quantization_api.quantize.prepare")
if prepare_custom_config_dict is None:
prepare_custom_config_dict = {}
custom_module_class_mapping = prepare_custom_config_dict.get("float_to_observed_custom_module_class", {})

if not inplace:
model = copy.deepcopy(model)

@@ -210,7 +223,9 @@
"passed correct configuration through `qconfig_dict` or "
"by assigning the `.qconfig` attribute directly on submodules")

add_observer_(model, qconfig_propagation_list, observer_non_leaf_module_list)
add_observer_(
model, qconfig_propagation_list, observer_non_leaf_module_list,
custom_module_class_mapping=custom_module_class_mapping)
return model

def _remove_qconfig(module):
@@ -380,29 +395,48 @@ def quantize_qat(model, run_fn, run_args, inplace=False):
convert(model, inplace=True)
return model

def convert(module, mapping=None, inplace=False, remove_qconfig=True):
def convert(
module, mapping=None, inplace=False, remove_qconfig=True,
convert_custom_config_dict=None):
r"""Converts submodules in input module to a different module according to `mapping`
by calling the `from_float` method on the target module class, and removes the qconfig
at the end if `remove_qconfig` is set to True.
Args:
module: input module
mapping: a dictionary that maps from source module type to target
module type, can be overwritten to allow swapping user defined
Modules
inplace: carry out model transformations in-place, the original module
is mutated
`module`: input module
`mapping`: a dictionary that maps from source module type to target
module type, can be overwritten to allow swapping user defined
Modules
`inplace`: carry out model transformations in-place, the original module
is mutated
`convert_custom_config_dict`: custom configuration dictionary for convert function
Example:
```python
convert_custom_config_dict = {
# user will manually define the corresponding quantized
# module class which has a from_observed class method that converts
# observed custom module to quantized custom module
"observed_to_quantized_custom_module_class": {
ObservedCustomModule: QuantizedCustomModule
}
}
```
"""
torch._C._log_api_usage_once("quantization_api.quantize.convert")
if not inplace:
module = copy.deepcopy(module)
_convert(module, mapping, inplace=True)
_convert(
module, mapping, inplace=True,
convert_custom_config_dict=convert_custom_config_dict)
if remove_qconfig:
_remove_qconfig(module)
return module

def _convert(module, mapping=None, inplace=False):
def _convert(
module, mapping=None, inplace=False,
convert_custom_config_dict=None):
r"""Converts submodules in input module to a different module according to `mapping`
by calling the `from_float` method on the target module class
@@ -417,6 +451,10 @@ def _convert(module, mapping=None, inplace=False):
"""
if mapping is None:
mapping = get_static_quant_module_mappings()
if convert_custom_config_dict is None:
convert_custom_config_dict = {}
custom_module_class_mapping = convert_custom_config_dict.get("observed_to_quantized_custom_module_class", {})

if not inplace:
module = copy.deepcopy(module)
reassign = {}
@@ -440,16 +478,17 @@
# both swappable modules and observed custom modules are
# swapped as one unit
if type(mod) not in SWAPPABLE_MODULES and \
not is_observed_custom_module(mod):
_convert(mod, mapping, inplace=True)
reassign[name] = swap_module(mod, mapping)
type(mod) not in custom_module_class_mapping:
_convert(mod, mapping, True, # inplace
custom_module_class_mapping)
reassign[name] = swap_module(mod, mapping, custom_module_class_mapping)

for key, value in reassign.items():
module._modules[key] = value

return module

def swap_module(mod, mapping):
def swap_module(mod, mapping, custom_module_class_mapping):
r"""Swaps the module if it has a quantized counterpart and it has an
`observer` attached.
@@ -464,8 +503,8 @@
# Always replace dequantstub with dequantize
if hasattr(mod, 'qconfig') and mod.qconfig is not None or type(mod) == DeQuantStub:
swapped = False
if is_observed_custom_module(mod):
new_mod = get_quantized_custom_module_class(mod._FLOAT_MODULE).from_observed(mod)
if type(mod) in custom_module_class_mapping:
new_mod = custom_module_class_mapping[type(mod)].from_observed(mod)
swapped = True
elif type(mod) in mapping:
new_mod = mapping[type(mod)].from_float(mod)
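Note that `swap_module` now takes a required `custom_module_class_mapping` argument, so code that called it directly must be updated. A minimal sketch, where `observed_mod` is a placeholder for a submodule of a prepared model; an empty mapping reproduces the old behavior of no custom swaps:

```python
from torch.quantization import swap_module
from torch.quantization.quantization_mappings import (
    get_static_quant_module_mappings,
)

# An empty custom mapping keeps the pre-commit behavior; pass
# {ObservedCustomModule: QuantizedCustomModule} to also swap custom modules.
new_mod = swap_module(observed_mod, get_static_quant_module_mappings(), {})
```
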
