Skip to content

Commit

Permalink
[quant][eagermode] Add additional_fuser_method_mapping to config
Browse files Browse the repository at this point in the history
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: c8663123d3b3ddbcc1a210bcb9dcba970086cd70
Pull Request resolved: #46355
  • Loading branch information
jerryzh168 committed Oct 21, 2020
1 parent cf9e804 commit c675400
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 10 deletions.
27 changes: 19 additions & 8 deletions torch/quantization/fuse_modules.py
Expand Up @@ -28,7 +28,7 @@ def _set_module(model, submodule_key, module):

setattr(cur_mod, tokens[-1], module)

def fuse_known_modules(mod_list):
def fuse_known_modules(mod_list, additional_fuser_method_mapping=None):
r"""Returns a list of modules that fuses the operations specified
in the input module list.
Expand All @@ -41,7 +41,7 @@ def fuse_known_modules(mod_list):
the fused operation. The rest of the elements are set to nn.Identity()
"""
types = tuple(type(m) for m in mod_list)
fuser_method = get_fuser_method(types)
fuser_method = get_fuser_method(types, additional_fuser_method_mapping)
if fuser_method is None:
raise NotImplementedError("Cannot fuse modules: {}".format(types))
new_mod : List[Optional[nn.Module]] = [None] * len(mod_list)
Expand All @@ -64,20 +64,23 @@ def fuse_known_modules(mod_list):

return new_mod

def _fuse_modules(model, modules_to_fuse, fuser_func=fuse_known_modules):

# TODO: remove fuser_func?
def _fuse_modules(model, modules_to_fuse, fuser_func=fuse_known_modules, fuse_custom_config_dict=None):
    """Fuse one sequence of named submodules of ``model`` in place.

    Resolves each dotted name in ``modules_to_fuse`` to its submodule,
    hands the resulting module list (plus any
    ``"additional_fuser_method_mapping"`` taken from
    ``fuse_custom_config_dict``) to ``fuser_func``, and writes the fused
    modules back into ``model`` under the same names.
    """
    config = fuse_custom_config_dict if fuse_custom_config_dict is not None else {}
    extra_fuser_methods = config.get("additional_fuser_method_mapping", {})

    # Resolve the dotted names to the actual submodule objects.
    modules = [_get_module(model, name) for name in modules_to_fuse]

    # Fuse list of modules; fuser_func returns a list of the same length.
    fused_modules = fuser_func(modules, extra_fuser_methods)

    # Replace each original submodule with its fused counterpart.
    for name, fused in zip(modules_to_fuse, fused_modules):
        _set_module(model, name, fused)

def fuse_modules(model, modules_to_fuse, inplace=False, fuser_func=fuse_known_modules):
def fuse_modules(model, modules_to_fuse, inplace=False, fuser_func=fuse_known_modules, fuse_custom_config_dict=None):
r"""Fuses a list of modules into a single module
Fuses only the following sequence of modules:
Expand All @@ -101,6 +104,14 @@ def fuse_modules(model, modules_to_fuse, inplace=False, fuser_func=fuse_known_mo
of the same length. For example,
fuser_func([convModule, BNModule]) returns the list [ConvBNModule, nn.Identity()]
Defaults to torch.quantization.fuse_known_modules
`fuse_custom_config_dict`: custom configuration for fusion:
Example:
fuse_custom_config_dict = {
# Additional fuser_method mapping
"additional_fuser_method_mapping": {
(torch.nn.Conv2d, torch.nn.BatchNorm2d): fuse_conv_bn
},
}
Returns:
model with fused modules. A new copy is created if inplace=False.
Expand All @@ -124,9 +135,9 @@ def fuse_modules(model, modules_to_fuse, inplace=False, fuser_func=fuse_known_mo

if all(isinstance(module_element, str) for module_element in modules_to_fuse):
# Handle case of modules_to_fuse being a list
_fuse_modules(model, modules_to_fuse, fuser_func)
_fuse_modules(model, modules_to_fuse, fuser_func, fuse_custom_config_dict)
else:
# Handle case of modules_to_fuse being a list of lists
for module_list in modules_to_fuse:
_fuse_modules(model, module_list, fuser_func)
_fuse_modules(model, module_list, fuser_func, fuse_custom_config_dict)
return model
2 changes: 2 additions & 0 deletions torch/quantization/quantize.py
Expand Up @@ -212,6 +212,7 @@ def prepare(model, inplace=False, allow_list=None,
if not inplace:
model = copy.deepcopy(model)

# TODO: remove allow_list
qconfig_propagation_list = allow_list
if qconfig_propagation_list is None:
qconfig_propagation_list = get_default_qconfig_propagation_list()
Expand Down Expand Up @@ -365,6 +366,7 @@ def prepare_qat(model, mapping=None, inplace=False):
torch._C._log_api_usage_once("quantization_api.quantize.prepare_qat")
if mapping is None:
mapping = get_default_qat_module_mappings()

if not inplace:
model = copy.deepcopy(model)

Expand Down
4 changes: 2 additions & 2 deletions torch/quantization/quantize_fx.py
Expand Up @@ -177,12 +177,12 @@ def prepare_fx(model, qconfig_dict, inplace=False, prepare_custom_config_dict=No
# Additional fuser_method mapping
"additional_fuser_method_mapping": {
(ModuleClass1, ModuleClass2): fuse_module1_module2
(torch.nn.Conv2d, torch.nn.BatchNorm2d): fuse_conv_bn
},
# Additional module mapping for qat
"additional_qat_module_mapping": {
FloatModule: QATModule
torch.nn.intrinsic.ConvBn2d: torch.nn.qat.ConvBn2d
},
# Additional fusion patterns
Expand Down

0 comments on commit c675400

Please sign in to comment.