[quant][eagermode] Move custom_module registration to prepare/convert_custom_config_dict #46293

Closed
25 changes: 17 additions & 8 deletions test/quantization/test_quantize.py
@@ -23,8 +23,6 @@
per_channel_dynamic_qconfig,
float16_dynamic_qconfig,
float_qparams_dynamic_qconfig,
register_observed_custom_module_mapping,
register_quantized_custom_module_mapping,
PerChannelMinMaxObserver,
QConfigDynamic,
default_dynamic_quant_observer
@@ -627,9 +625,6 @@ def from_observed(cls, observed_module):
quantized = cls(nnq.Conv2d.from_float(observed_module.conv))
return quantized

register_observed_custom_module_mapping(CustomModule, ObservedCustomModule)
register_quantized_custom_module_mapping(CustomModule, QuantizedCustomModule)

class M(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -670,14 +665,28 @@ def forward(self, x):
original_ref_m.conv2.bias = torch.nn.Parameter(original_m.custom.conv.bias.detach())

original_m.qconfig = default_qconfig
m = prepare(original_m)
self.checkObservers(m)
prepare_custom_config_dict = {
"float_to_observed_custom_module_class": {
CustomModule: ObservedCustomModule
}
}
convert_custom_config_dict = {
"observed_to_quantized_custom_module_class": {
ObservedCustomModule: QuantizedCustomModule
}
}
m = prepare(
original_m,
prepare_custom_config_dict=prepare_custom_config_dict)
self.checkObservers(m, None, prepare_custom_config_dict)
# calibration
m(data)
# all activation observers are inserted in the top level module

# check converted/quantized model
m = convert(m)
m = convert(
m,
convert_custom_config_dict=convert_custom_config_dict)
# check if the module is properly quantized
self.assertEqual(type(m.quant), nnq.Quantize)
self.assertEqual(type(m.conv), nnq.Conv2d)
7 changes: 0 additions & 7 deletions torch/quantization/__init__.py
@@ -9,7 +9,6 @@
from .quantize_fx import *
from .quantization_mappings import *
from .fuser_method_mappings import *
from .custom_module_class_mappings import *

def default_eval_fn(model, calib_data):
r"""
@@ -41,12 +40,6 @@ def default_eval_fn(model, calib_data):
'get_compare_output_module_list',
'register_quantized_operator_mapping', 'get_quantized_operator',
'register_fuser_method', 'get_fuser_method',
'register_observed_custom_module_mapping',
'get_observed_custom_module_class',
'register_quantized_custom_mdoule_mapping',
'get_quantized_custom_module_class',
'is_custom_module_class',
'is_observed_custom_module',
# Sub functions for `prepare` and `swap_module`
'propagate_qconfig_', 'add_quant_dequant', 'add_observer_', 'swap_module',
'default_eval_fn', 'get_observer_dict',
75 changes: 0 additions & 75 deletions torch/quantization/custom_module_class_mappings.py

This file was deleted.
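
For context, the deleted file held the process-global registration API whose call sites are removed above; after this PR the same mapping is passed per call instead. A minimal before/after sketch, assuming `CustomModule`, `ObservedCustomModule`, and `QuantizedCustomModule` are user-defined classes like the ones in the test above:

```python
# Before this PR: process-wide registration (API deleted in this change):
#
#   register_observed_custom_module_mapping(CustomModule, ObservedCustomModule)
#   register_quantized_custom_module_mapping(CustomModule, QuantizedCustomModule)
#
# After this PR: the mapping is scoped to a single prepare/convert call, so
# different invocations can use different custom-module mappings.
prepare_custom_config_dict = {
    "float_to_observed_custom_module_class": {
        CustomModule: ObservedCustomModule
    }
}
convert_custom_config_dict = {
    "observed_to_quantized_custom_module_class": {
        ObservedCustomModule: QuantizedCustomModule
    }
}
```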

79 changes: 54 additions & 25 deletions torch/quantization/quantize.py
@@ -14,14 +14,6 @@
get_qat_module_mappings,
get_qconfig_propagation_list)

from .custom_module_class_mappings import (
is_custom_module_class,
get_observed_custom_module_class,
get_quantized_custom_module_class,
mark_observed_custom_module,
is_observed_custom_module,
)

from .stubs import DeQuantStub, QuantWrapper
from .qconfig import default_dynamic_qconfig, float16_dynamic_qconfig, float_qparams_dynamic_qconfig

@@ -86,7 +78,7 @@ def register_activation_post_process_hook(module):
'Expect activation_post_process attribut already attached to the module'
return module.register_forward_hook(_observer_forward_hook)

def add_observer_(module, qconfig_propagation_list=None, non_leaf_module_list=None, device=None):
def add_observer_(module, qconfig_propagation_list=None, non_leaf_module_list=None, device=None, custom_module_class_mapping=None):
r"""Add observer for the leaf child of the module.

This function insert observer module to all leaf child module that
@@ -103,6 +95,9 @@
if qconfig_propagation_list is None:
qconfig_propagation_list = get_qconfig_propagation_list()

if custom_module_class_mapping is None:
custom_module_class_mapping = {}

# respect device affinity when adding observers
if device is None:
devices = get_unique_devices_(module)
@@ -139,13 +134,12 @@ def insert_activation_post_process(m):
child.activation_post_process = get_activation_post_process(child.qconfig, device)
elif non_leaf_module_list is not None and type(child) in non_leaf_module_list:
insert_activation_post_process(child)
elif needs_observation(child) and is_custom_module_class(type(child)):
observed_child = get_observed_custom_module_class(type(child)).from_float(child)
mark_observed_custom_module(observed_child, type(child))
elif needs_observation(child) and type(child) in custom_module_class_mapping:
observed_child = custom_module_class_mapping[type(child)].from_float(child)
setattr(module, name, observed_child)
insert_activation_post_process(observed_child)
else:
add_observer_(child, qconfig_propagation_list, non_leaf_module_list, device)
add_observer_(child, qconfig_propagation_list, non_leaf_module_list, device, custom_module_class_mapping)

# Insert observers only for leaf nodes, note that this observer is for
# the output of the module, for input QuantStub will observe them
@@ -180,7 +174,8 @@ def add_quant_dequant(module):
return module

def prepare(model, inplace=False, allow_list=None,
observer_non_leaf_module_list=None):
observer_non_leaf_module_list=None,
prepare_custom_config_dict=None):
r"""Prepares a copy of the model for quantization calibration or quantization-aware training.

Quantization configuration should be assigned preemptively
@@ -194,8 +189,21 @@ def prepare(model, inplace=False, allow_list=None,
inplace: carry out model transformations in-place, the original module is mutated
allow_list: list of quantizable modules
observer_non_leaf_module_list: list of non-leaf modules we want to add observer
prepare_custom_config_dict: customization configuration dictionary for the prepare function:
# the user will manually define the corresponding observed
# module class, which has a from_float class method that converts
# the float custom module to an observed custom module
prepare_custom_config_dict = {
"float_to_observed_custom_module_class": {
CustomModule: ObservedCustomModule
}
}
"""
torch._C._log_api_usage_once("quantization_api.quantize.prepare")
if prepare_custom_config_dict is None:
prepare_custom_config_dict = {}
custom_module_class_mapping = prepare_custom_config_dict.get("float_to_observed_custom_module_class", {})

if not inplace:
model = copy.deepcopy(model)

@@ -210,7 +218,9 @@ def prepare(model, inplace=False, allow_list=None,
"passed correct configuration through `qconfig_dict` or "
"by assigning the `.qconfig` attribute directly on submodules")

add_observer_(model, qconfig_propagation_list, observer_non_leaf_module_list)
add_observer_(
model, qconfig_propagation_list, observer_non_leaf_module_list,
custom_module_class_mapping=custom_module_class_mapping)
return model

def _remove_qconfig(module):
@@ -380,7 +390,9 @@ def quantize_qat(model, run_fn, run_args, inplace=False):
convert(model, inplace=True)
return model

def convert(module, mapping=None, inplace=False, remove_qconfig=True):
def convert(
module, mapping=None, inplace=False, remove_qconfig=True,
convert_custom_config_dict=None):
r"""Converts submodules in input module to a different module according to `mapping`
by calling `from_float` method on the target module class. And remove qconfig at the
end if remove_qconfig is set to True.
@@ -392,17 +404,29 @@ def convert(module, mapping=None, inplace=False, remove_qconfig=True):
Modules
inplace: carry out model transformations in-place, the original module
is mutated

convert_custom_config_dict: custom configuration dictionary for the convert function:
convert_custom_config_dict = {
# the user will manually define the corresponding quantized
# module class, which has a from_observed class method that converts
# the observed custom module to a quantized custom module
"observed_to_quantized_custom_module_class": {
ObservedCustomModule: QuantizedCustomModule
}
}
"""
torch._C._log_api_usage_once("quantization_api.quantize.convert")
if not inplace:
module = copy.deepcopy(module)
_convert(module, mapping, inplace=True)
_convert(
module, mapping, inplace=True,
convert_custom_config_dict=convert_custom_config_dict)
if remove_qconfig:
_remove_qconfig(module)
return module

def _convert(module, mapping=None, inplace=False):
def _convert(
module, mapping=None, inplace=False,
convert_custom_config_dict=None):
r"""Converts submodules in input module to a different module according to `mapping`
by calling `from_float` method on the target module class

@@ -417,6 +441,10 @@ def _convert(module, mapping=None, inplace=False):
"""
if mapping is None:
mapping = get_static_quant_module_mappings()
if convert_custom_config_dict is None:
convert_custom_config_dict = {}
custom_module_class_mapping = convert_custom_config_dict.get("observed_to_quantized_custom_module_class", {})

if not inplace:
module = copy.deepcopy(module)
reassign = {}
@@ -440,16 +468,17 @@
# both swappable modules and observed custom modules are
# swapped as one unit
if type(mod) not in SWAPPABLE_MODULES and \
not is_observed_custom_module(mod):
_convert(mod, mapping, inplace=True)
reassign[name] = swap_module(mod, mapping)
type(mod) not in custom_module_class_mapping:
_convert(mod, mapping, True, # inplace
custom_module_class_mapping)
reassign[name] = swap_module(mod, mapping, custom_module_class_mapping)

for key, value in reassign.items():
module._modules[key] = value

return module

def swap_module(mod, mapping):
def swap_module(mod, mapping, custom_module_class_mapping):
r"""Swaps the module if it has a quantized counterpart and it has an
`observer` attached.

@@ -464,8 +493,8 @@ def swap_module(mod, mapping):
# Always replace dequantstub with dequantize
if hasattr(mod, 'qconfig') and mod.qconfig is not None or type(mod) == DeQuantStub:
swapped = False
if is_observed_custom_module(mod):
new_mod = get_quantized_custom_module_class(mod._FLOAT_MODULE).from_observed(mod)
if type(mod) in custom_module_class_mapping:
new_mod = custom_module_class_mapping[type(mod)].from_observed(mod)
swapped = True
elif type(mod) in mapping:
new_mod = mapping[type(mod)].from_float(mod)
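
Taken together, the quantize.py changes thread the custom-module mapping from prepare/convert down through add_observer_, _convert, and swap_module. Below is a hedged, self-contained sketch of the resulting eager-mode workflow; the module and class names are illustrative (modeled on the test above), not part of torch, and actually running it requires a quantization backend such as fbgemm or qnnpack.

```python
# End-to-end sketch of the eager-mode custom-module workflow added here.
# CustomModule / ObservedCustomModule / QuantizedCustomModule are
# illustrative user-defined classes, not part of torch.
import torch
import torch.nn as nn
import torch.nn.quantized as nnq
from torch.quantization import (
    QuantStub, DeQuantStub, default_qconfig, prepare, convert)

class CustomModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 1)

    def forward(self, x):
        return self.conv(x)

class ObservedCustomModule(nn.Module):
    def __init__(self, conv):
        super().__init__()
        self.conv = conv

    def forward(self, x):
        return self.conv(x)

    @classmethod
    def from_float(cls, float_module):
        # called by prepare() for float module classes listed under
        # "float_to_observed_custom_module_class"
        assert hasattr(float_module, 'qconfig')
        observed = cls(float_module.conv)
        observed.qconfig = float_module.qconfig
        return observed

class QuantizedCustomModule(nn.Module):
    def __init__(self, conv):
        super().__init__()
        self.conv = conv

    def forward(self, x):
        return self.conv(x)

    @classmethod
    def from_observed(cls, observed_module):
        # called by convert() for observed module classes listed under
        # "observed_to_quantized_custom_module_class"; reuse the observer
        # that prepare() attached to the observed module for the inner conv
        assert hasattr(observed_module, 'activation_post_process')
        observed_module.conv.activation_post_process = \
            observed_module.activation_post_process
        return cls(nnq.Conv2d.from_float(observed_module.conv))

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = QuantStub()
        self.custom = CustomModule()
        self.dequant = DeQuantStub()

    def forward(self, x):
        return self.dequant(self.custom(self.quant(x)))

model = M()
model.qconfig = default_qconfig
prepared = prepare(
    model,
    prepare_custom_config_dict={
        "float_to_observed_custom_module_class": {
            CustomModule: ObservedCustomModule
        }
    })
prepared(torch.randn(1, 1, 4, 4))  # calibration
quantized = convert(
    prepared,
    convert_custom_config_dict={
        "observed_to_quantized_custom_module_class": {
            ObservedCustomModule: QuantizedCustomModule
        }
    })
assert isinstance(quantized.custom, QuantizedCustomModule)
```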
21 changes: 10 additions & 11 deletions torch/testing/_internal/common_quantization.py
@@ -13,10 +13,6 @@
default_qconfig, default_dynamic_qconfig, default_per_channel_qconfig, QConfig, default_observer, default_weight_observer, \
propagate_qconfig_, convert, get_default_qconfig, quantize_dynamic_jit, quantize_jit, float_qparams_dynamic_qconfig, \
get_default_qat_qconfig, PerChannelMinMaxObserver, default_dynamic_quant_observer, QConfigDynamic
from torch.quantization import (
is_custom_module_class,
is_observed_custom_module,
)
from torch.quantization.quantization_mappings import (
get_dynamic_quant_module_mappings,
get_qconfig_propagation_list,
@@ -341,12 +337,15 @@ def checkHasPrepModules(self, module):
self.assertTrue(hasattr(module, 'quant'))
self.assertTrue(hasattr(module, 'dequant'))

def checkObservers(self, module, propagate_qconfig_list=None):
def checkObservers(self, module, propagate_qconfig_list=None, prepare_custom_config_dict=None):
r"""Checks the module or module's leaf descendants
have observers in preperation for quantization
"""
if propagate_qconfig_list is None:
propagate_qconfig_list = get_qconfig_propagation_list()
if prepare_custom_config_dict is None:
prepare_custom_config_dict = {}
float_to_observed_module_class_mapping = prepare_custom_config_dict.get("float_to_observed_custom_module_class", {})

# check if a module is a leaf module, ignoring activation_post_process attribute
def is_leaf_module(module):
Expand All @@ -356,18 +355,18 @@ def is_leaf_module(module):
submodule_name_count += 1
return submodule_name_count == 0

if (hasattr(module, 'qconfig') and module.qconfig is not None and
is_leaf_module(module) and not isinstance(module, torch.nn.Sequential)
and type(module) in propagate_qconfig_list) or \
is_custom_module_class(type(module)):
if hasattr(module, 'qconfig') and module.qconfig is not None and \
((is_leaf_module(module) and not isinstance(module, torch.nn.Sequential)
and type(module) in propagate_qconfig_list) or
type(module) in float_to_observed_module_class_mapping.keys()):
self.assertTrue(hasattr(module, 'activation_post_process'),
'module: ' + str(type(module)) + ' do not have observer')
# we don't need to check observers for child modules of the
# qat modules
if type(module) not in get_qat_module_mappings().values() and \
not is_observed_custom_module(module):
type(module) not in float_to_observed_module_class_mapping.values():
for child in module.children():
self.checkObservers(child)
self.checkObservers(child, propagate_qconfig_list, prepare_custom_config_dict)

def checkQuantDequant(self, mod):
r"""Checks that mod has nn.Quantize and
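
The test-helper change above means observer checks need the same prepare config that was given to prepare. A small sketch of the updated call pattern inside a test, assuming `self` is a test case deriving from this helper class and `prepared` came from `prepare` with the same `prepare_custom_config_dict`:

```python
# checkObservers now takes the prepare config so that observed custom
# module classes (the values under "float_to_observed_custom_module_class")
# are recognized and checked for an attached activation_post_process.
self.checkObservers(prepared, None, prepare_custom_config_dict)
```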