Revert D24290811: [quant][eagermode] Move custom_module registration to prepare/convert_custom_config_dict

Test Plan: revert-hammer

Differential Revision: D24290811 (3ad797c)

Original commit changeset: 7d2aee98e194

fbshipit-source-id: 24013e92044f2a1b36b1a9f475bbaa6f17bdaa11
Mike Ruberry authored and facebook-github-bot committed Oct 14, 2020
1 parent a38eeef commit ff0af72
Showing 5 changed files with 126 additions and 81 deletions.
25 changes: 8 additions & 17 deletions test/quantization/test_quantize.py
@@ -23,6 +23,8 @@
per_channel_dynamic_qconfig,
float16_dynamic_qconfig,
float_qparams_dynamic_qconfig,
register_observed_custom_module_mapping,
register_quantized_custom_module_mapping,
PerChannelMinMaxObserver,
QConfigDynamic,
default_dynamic_quant_observer
@@ -625,6 +627,9 @@ def from_observed(cls, observed_module):
quantized = cls(nnq.Conv2d.from_float(observed_module.conv))
return quantized

register_observed_custom_module_mapping(CustomModule, ObservedCustomModule)
register_quantized_custom_module_mapping(CustomModule, QuantizedCustomModule)

class M(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -665,28 +670,14 @@ def forward(self, x):
original_ref_m.conv2.bias = torch.nn.Parameter(original_m.custom.conv.bias.detach())

original_m.qconfig = default_qconfig
prepare_custom_config_dict = {
"float_to_observed_custom_module_class": {
CustomModule: ObservedCustomModule
}
}
convert_custom_config_dict = {
"observed_to_quantized_custom_module_class": {
ObservedCustomModule: QuantizedCustomModule
}
}
m = prepare(
original_m,
prepare_custom_config_dict=prepare_custom_config_dict)
self.checkObservers(m, None, prepare_custom_config_dict)
m = prepare(original_m)
self.checkObservers(m)
# calibration
m(data)
# all activation observers are inserted in the top level module

# check converted/quantized model
m = convert(
m,
convert_custom_config_dict=convert_custom_config_dict)
m = convert(m)
# check if the module is properly quantized
self.assertEqual(type(m.quant), nnq.Quantize)
self.assertEqual(type(m.conv), nnq.Conv2d)
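The hunks above are truncated, so here is a minimal, self-contained sketch of the registration-based flow this revert restores. Class and variable names mirror the test, but the bodies not visible in the diff (channel sizes, calibration data, the exact asserts) are illustrative assumptions rather than the verbatim test code:

import torch
import torch.nn as nn
import torch.nn.quantized as nnq
from torch.quantization import (
    QuantStub, DeQuantStub, default_qconfig, prepare, convert,
    register_observed_custom_module_mapping,
    register_quantized_custom_module_mapping,
)

class CustomModule(nn.Module):
    """Float custom module that eager-mode quantization cannot swap on its own."""
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, 3)

    def forward(self, x):
        return self.conv(x)

class ObservedCustomModule(nn.Module):
    """Produced by prepare(); must define a from_float classmethod."""
    def __init__(self, conv):
        super().__init__()
        self.conv = conv

    def forward(self, x):
        return self.conv(x)

    @classmethod
    def from_float(cls, float_module):
        assert hasattr(float_module, 'qconfig')
        observed = cls(float_module.conv)
        observed.qconfig = float_module.qconfig  # keep qconfig so convert() will swap it later
        return observed

class QuantizedCustomModule(nn.Module):
    """Produced by convert(); must define a from_observed classmethod."""
    def __init__(self, conv):
        super().__init__()
        self.conv = conv

    def forward(self, x):
        return self.conv(x)

    @classmethod
    def from_observed(cls, observed_module):
        assert hasattr(observed_module, 'activation_post_process')
        # reuse the output observer that prepare() attached to the observed custom module
        observed_module.conv.activation_post_process = observed_module.activation_post_process
        return cls(nnq.Conv2d.from_float(observed_module.conv))

# one-time global registration, replacing the per-call config dicts removed above
register_observed_custom_module_mapping(CustomModule, ObservedCustomModule)
register_quantized_custom_module_mapping(CustomModule, QuantizedCustomModule)

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = QuantStub()
        self.custom = CustomModule()
        self.dequant = DeQuantStub()

    def forward(self, x):
        return self.dequant(self.custom(self.quant(x)))

m = M()
m.qconfig = default_qconfig
m = prepare(m)                  # no prepare_custom_config_dict
m(torch.randn(1, 3, 8, 8))      # calibration
m = convert(m)                  # no convert_custom_config_dict
assert isinstance(m.custom, QuantizedCustomModule)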
7 changes: 7 additions & 0 deletions torch/quantization/__init__.py
@@ -9,6 +9,7 @@
from .quantize_fx import *
from .quantization_mappings import *
from .fuser_method_mappings import *
from .custom_module_class_mappings import *

def default_eval_fn(model, calib_data):
r"""
@@ -40,6 +41,12 @@ def default_eval_fn(model, calib_data):
'get_compare_output_module_list',
'register_quantized_operator_mapping', 'get_quantized_operator',
'register_fuser_method', 'get_fuser_method',
'register_observed_custom_module_mapping',
'get_observed_custom_module_class',
'register_quantized_custom_module_mapping',
'get_quantized_custom_module_class',
'is_custom_module_class',
'is_observed_custom_module',
# Sub functions for `prepare` and `swap_module`
'propagate_qconfig_', 'add_quant_dequant', 'add_observer_', 'swap_module',
'default_eval_fn', 'get_observer_dict',
75 changes: 75 additions & 0 deletions torch/quantization/custom_module_class_mappings.py
@@ -0,0 +1,75 @@
OBSERVED_CUSTOM_MODULE_CLASS_MAPPINGS = dict()

def register_observed_custom_module_mapping(float_custom_module_class, observed_custom_module_class):
""" Register a mapping from `float_custom_module_class` to
`observed_custom_module_class`
`observed_custom_module_class` will have a `from_float` classmethod,
which will return an observed custom module instance given
a float custom module instance.
This will be used in the prepare step of post-training static quantization or
quantization-aware training.
"""
assert hasattr(observed_custom_module_class, 'from_float'), 'from_float must be' + \
' defined in observed custom module class'
OBSERVED_CUSTOM_MODULE_CLASS_MAPPINGS[float_custom_module_class] = \
observed_custom_module_class

def get_observed_custom_module_class(float_custom_module_class):
""" Get the corresponding observed module class for a given
float custom module.
"""
observed_custom_module_class = \
OBSERVED_CUSTOM_MODULE_CLASS_MAPPINGS.get(float_custom_module_class, None)
assert observed_custom_module_class is not None, \
'Float Custom module class {}'.format(float_custom_module_class) + \
' does not have a corresponding observed module class'
return observed_custom_module_class

QUANTIZED_CUSTOM_MODULE_CLASS_MAPPINGS = dict()

def register_quantized_custom_module_mapping(float_custom_module_class, quantized_custom_module_class):
""" Register a mapping from `float_custom_module_class` to `quantized_custom_module_class`
A quantized custom module class should accept quantized input and
return quantized output. (we can relax this condition in the
future if there is a need)
`quantized_custom_module_class` will have a `from_observed` classmethod,
which will return a quantized custom module instance given
an observed custom module instance.
This will be used in the convert step of post-training static quantization or
quantization-aware training.
"""
assert hasattr(quantized_custom_module_class, 'from_observed'), 'from_observed' + \
' must be defined in quantized custom module class'
QUANTIZED_CUSTOM_MODULE_CLASS_MAPPINGS[float_custom_module_class] = \
quantized_custom_module_class

def get_quantized_custom_module_class(float_custom_module_class):
""" Get the corresponding quantized module class for a given
float custom module.
"""
quantized_custom_module_class = \
QUANTIZED_CUSTOM_MODULE_CLASS_MAPPINGS.get(float_custom_module_class, None)
assert quantized_custom_module_class is not None, \
'Float Custom module class {}'.format(float_custom_module_class) + \
' does not have a corresponding quantized module class'
return quantized_custom_module_class

def is_custom_module_class(module_class):
""" Check if a given module class is a custom module class
"""
return module_class in OBSERVED_CUSTOM_MODULE_CLASS_MAPPINGS and \
module_class in QUANTIZED_CUSTOM_MODULE_CLASS_MAPPINGS

def mark_observed_custom_module(module, custom_module_class):
""" Mark a module as observed custom module, so that
it can be identified during convert step
"""
module._is_observed_custom_module = True
module._FLOAT_MODULE = custom_module_class

def is_observed_custom_module(module):
""" Check if a module is marked as observed custom module
or not
"""
return hasattr(module, '_is_observed_custom_module') and \
module._is_observed_custom_module
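The new file above is a pair of module-level registries plus lookup and marking helpers. A short sketch of how they compose (the Float/Observed/Quantized classes here are hypothetical stand-ins; prepare and convert drive these calls internally, as the quantize.py diff below shows):

import torch
from torch.quantization.custom_module_class_mappings import (
    register_observed_custom_module_mapping,
    register_quantized_custom_module_mapping,
    get_observed_custom_module_class,
    get_quantized_custom_module_class,
    is_custom_module_class,
    mark_observed_custom_module,
    is_observed_custom_module,
)

class MyFloat(torch.nn.Module):
    pass

class MyObserved(torch.nn.Module):
    @classmethod
    def from_float(cls, mod):      # required by register_observed_custom_module_mapping
        return cls()

class MyQuantized(torch.nn.Module):
    @classmethod
    def from_observed(cls, mod):   # required by register_quantized_custom_module_mapping
        return cls()

register_observed_custom_module_mapping(MyFloat, MyObserved)
register_quantized_custom_module_mapping(MyFloat, MyQuantized)

assert is_custom_module_class(MyFloat)          # present in both registries

# roughly what add_observer_ does at prepare time:
observed = get_observed_custom_module_class(MyFloat).from_float(MyFloat())
mark_observed_custom_module(observed, MyFloat)  # records _FLOAT_MODULE for convert
assert is_observed_custom_module(observed)

# roughly what swap_module does at convert time:
quantized = get_quantized_custom_module_class(observed._FLOAT_MODULE).from_observed(observed)
assert isinstance(quantized, MyQuantized)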
79 changes: 25 additions & 54 deletions torch/quantization/quantize.py
@@ -14,6 +14,14 @@
get_qat_module_mappings,
get_qconfig_propagation_list)

from .custom_module_class_mappings import (
is_custom_module_class,
get_observed_custom_module_class,
get_quantized_custom_module_class,
mark_observed_custom_module,
is_observed_custom_module,
)

from .stubs import DeQuantStub, QuantWrapper
from .qconfig import default_dynamic_qconfig, float16_dynamic_qconfig, float_qparams_dynamic_qconfig

@@ -78,7 +86,7 @@ def register_activation_post_process_hook(module):
'Expect activation_post_process attribute already attached to the module'
return module.register_forward_hook(_observer_forward_hook)

def add_observer_(module, qconfig_propagation_list=None, non_leaf_module_list=None, device=None, custom_module_class_mapping=None):
def add_observer_(module, qconfig_propagation_list=None, non_leaf_module_list=None, device=None):
r"""Add observer for the leaf child of the module.
This function inserts observer modules into all leaf child modules that
@@ -95,9 +103,6 @@ def add_observer_(module, qconfig_propagation_list=None, non_leaf_module_list=No
if qconfig_propagation_list is None:
qconfig_propagation_list = get_qconfig_propagation_list()

if custom_module_class_mapping is None:
custom_module_class_mapping = {}

# respect device affinity when adding observers
if device is None:
devices = get_unique_devices_(module)
@@ -134,12 +139,13 @@ def insert_activation_post_process(m):
child.activation_post_process = get_activation_post_process(child.qconfig, device)
elif non_leaf_module_list is not None and type(child) in non_leaf_module_list:
insert_activation_post_process(child)
elif needs_observation(child) and type(child) in custom_module_class_mapping:
observed_child = custom_module_class_mapping[type(child)].from_float(child)
elif needs_observation(child) and is_custom_module_class(type(child)):
observed_child = get_observed_custom_module_class(type(child)).from_float(child)
mark_observed_custom_module(observed_child, type(child))
setattr(module, name, observed_child)
insert_activation_post_process(observed_child)
else:
add_observer_(child, qconfig_propagation_list, non_leaf_module_list, device, custom_module_class_mapping)
add_observer_(child, qconfig_propagation_list, non_leaf_module_list, device)

# Insert observers only for leaf nodes, note that this observer is for
# the output of the module, for input QuantStub will observe them
@@ -174,8 +180,7 @@ def add_quant_dequant(module):
return module

def prepare(model, inplace=False, allow_list=None,
observer_non_leaf_module_list=None,
prepare_custom_config_dict=None):
observer_non_leaf_module_list=None):
r"""Prepares a copy of the model for quantization calibration or quantization-aware training.
Quantization configuration should be assigned preemptively
@@ -189,21 +194,8 @@ def prepare(model, inplace=False, allow_list=None,
inplace: carry out model transformations in-place, the original module is mutated
allow_list: list of quantizable modules
observer_non_leaf_module_list: list of non-leaf modules we want to add observer
`prepare_custom_config_dict`: customization configuration dictionary for prepare function:
# user will manually define the corresponding observed
# module class which has a from_float class method that converts
# float custom module to observed custom module
prepare_custom_config_dict = {
"float_to_observed_custom_module_class": {
CustomModule: ObservedCustomModule
}
}
"""
torch._C._log_api_usage_once("quantization_api.quantize.prepare")
if prepare_custom_config_dict is None:
prepare_custom_config_dict = {}
custom_module_class_mapping = prepare_custom_config_dict.get("float_to_observed_custom_module_class", {})

if not inplace:
model = copy.deepcopy(model)

@@ -218,9 +210,7 @@
"passed correct configuration through `qconfig_dict` or "
"by assigning the `.qconfig` attribute directly on submodules")

add_observer_(
model, qconfig_propagation_list, observer_non_leaf_module_list,
custom_module_class_mapping=custom_module_class_mapping)
add_observer_(model, qconfig_propagation_list, observer_non_leaf_module_list)
return model

def _remove_qconfig(module):
Expand Down Expand Up @@ -390,9 +380,7 @@ def quantize_qat(model, run_fn, run_args, inplace=False):
convert(model, inplace=True)
return model

def convert(
module, mapping=None, inplace=False, remove_qconfig=True,
convert_custom_config_dict=None):
def convert(module, mapping=None, inplace=False, remove_qconfig=True):
r"""Converts submodules in input module to a different module according to `mapping`
by calling the `from_float` method on the target module class, and removes qconfig at the
end if remove_qconfig is set to True.
@@ -404,29 +392,17 @@ def convert(
Modules
inplace: carry out model transformations in-place, the original module
is mutated
`convert_custom_config_dict`: custom configuration dictionary for convert function:
convert_custom_config_dict = {
# user will manually define the corresponding quantized
# module class which has a from_observed class method that converts
# observed custom module to quantized custom module
"observed_to_quantized_custom_module_class": {
ObservedCustomModule: QuantizedCustomModule
}
}
"""
torch._C._log_api_usage_once("quantization_api.quantize.convert")
if not inplace:
module = copy.deepcopy(module)
_convert(
module, mapping, inplace=True,
convert_custom_config_dict=convert_custom_config_dict)
_convert(module, mapping, inplace=True)
if remove_qconfig:
_remove_qconfig(module)
return module

def _convert(
module, mapping=None, inplace=False,
convert_custom_config_dict=None):
def _convert(module, mapping=None, inplace=False):
r"""Converts submodules in input module to a different module according to `mapping`
by calling `from_float` method on the target module class
@@ -441,10 +417,6 @@ def _convert(
"""
if mapping is None:
mapping = get_static_quant_module_mappings()
if convert_custom_config_dict is None:
convert_custom_config_dict = {}
custom_module_class_mapping = convert_custom_config_dict.get("observed_to_quantized_custom_module_class", {})

if not inplace:
module = copy.deepcopy(module)
reassign = {}
@@ -468,17 +440,16 @@ def _convert(
# both swappable modules and observed custom modules are
# swapped as one unit
if type(mod) not in SWAPPABLE_MODULES and \
type(mod) not in custom_module_class_mapping:
_convert(mod, mapping, True, # inplace
custom_module_class_mapping)
reassign[name] = swap_module(mod, mapping, custom_module_class_mapping)
not is_observed_custom_module(mod):
_convert(mod, mapping, inplace=True)
reassign[name] = swap_module(mod, mapping)

for key, value in reassign.items():
module._modules[key] = value

return module

def swap_module(mod, mapping, custom_module_class_mapping):
def swap_module(mod, mapping):
r"""Swaps the module if it has a quantized counterpart and it has an
`observer` attached.
@@ -493,8 +464,8 @@ def swap_module(mod, mapping, custom_module_class_mapping):
# Always replace dequantstub with dequantize
if hasattr(mod, 'qconfig') and mod.qconfig is not None or type(mod) == DeQuantStub:
swapped = False
if type(mod) in custom_module_class_mapping:
new_mod = custom_module_class_mapping[type(mod)].from_observed(mod)
if is_observed_custom_module(mod):
new_mod = get_quantized_custom_module_class(mod._FLOAT_MODULE).from_observed(mod)
swapped = True
elif type(mod) in mapping:
new_mod = mapping[type(mod)].from_float(mod)
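Caller-facing summary of the quantize.py changes above: prepare and convert no longer accept the custom-config-dict keyword arguments, so eager-mode code written against the reverted API moves back to global registration. A rough sketch, reusing the CustomModule/ObservedCustomModule/QuantizedCustomModule classes and float model from the sketch earlier on this page (the old calls are kept only as comments):

# Previously (removed by this revert):
#   m = prepare(float_model, prepare_custom_config_dict={
#       "float_to_observed_custom_module_class": {CustomModule: ObservedCustomModule}})
#   m = convert(m, convert_custom_config_dict={
#       "observed_to_quantized_custom_module_class": {ObservedCustomModule: QuantizedCustomModule}})
#
# Now:
register_observed_custom_module_mapping(CustomModule, ObservedCustomModule)
register_quantized_custom_module_mapping(CustomModule, QuantizedCustomModule)
m = prepare(float_model)
m = convert(m)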
21 changes: 11 additions & 10 deletions torch/testing/_internal/common_quantization.py
@@ -13,6 +13,10 @@
default_qconfig, default_dynamic_qconfig, default_per_channel_qconfig, QConfig, default_observer, default_weight_observer, \
propagate_qconfig_, convert, get_default_qconfig, quantize_dynamic_jit, quantize_jit, float_qparams_dynamic_qconfig, \
get_default_qat_qconfig, PerChannelMinMaxObserver, default_dynamic_quant_observer, QConfigDynamic
from torch.quantization import (
is_custom_module_class,
is_observed_custom_module,
)
from torch.quantization.quantization_mappings import (
get_dynamic_quant_module_mappings,
get_qconfig_propagation_list,
@@ -337,15 +341,12 @@ def checkHasPrepModules(self, module):
self.assertTrue(hasattr(module, 'quant'))
self.assertTrue(hasattr(module, 'dequant'))

def checkObservers(self, module, propagate_qconfig_list=None, prepare_custom_config_dict=None):
def checkObservers(self, module, propagate_qconfig_list=None):
r"""Checks the module or module's leaf descendants
have observers in preperation for quantization
"""
if propagate_qconfig_list is None:
propagate_qconfig_list = get_qconfig_propagation_list()
if prepare_custom_config_dict is None:
prepare_custom_config_dict = {}
float_to_observed_module_class_mapping = prepare_custom_config_dict.get("float_to_observed_custom_module_class", {})

# check if a module is a leaf module, ignoring activation_post_process attribute
def is_leaf_module(module):
@@ -355,18 +356,18 @@ def is_leaf_module(module):
submodule_name_count += 1
return submodule_name_count == 0

if hasattr(module, 'qconfig') and module.qconfig is not None and \
((is_leaf_module(module) and not isinstance(module, torch.nn.Sequential)
and type(module) in propagate_qconfig_list) or
type(module) in float_to_observed_module_class_mapping.keys()):
if (hasattr(module, 'qconfig') and module.qconfig is not None and
is_leaf_module(module) and not isinstance(module, torch.nn.Sequential)
and type(module) in propagate_qconfig_list) or \
is_custom_module_class(type(module)):
self.assertTrue(hasattr(module, 'activation_post_process'),
'module: ' + str(type(module)) + ' does not have an observer')
# we don't need to check observers for child modules of the
# qat modules
if type(module) not in get_qat_module_mappings().values() and \
type(module) not in float_to_observed_module_class_mapping.values():
not is_observed_custom_module(module):
for child in module.children():
self.checkObservers(child, propagate_qconfig_list, prepare_custom_config_dict)
self.checkObservers(child)

def checkQuantDequant(self, mod):
r"""Checks that mod has nn.Quantize and
