diff --git a/torch/quantization/fuse_modules.py b/torch/quantization/fuse_modules.py index d8f06e538726..e6669519b964 100644 --- a/torch/quantization/fuse_modules.py +++ b/torch/quantization/fuse_modules.py @@ -28,7 +28,7 @@ def _set_module(model, submodule_key, module): setattr(cur_mod, tokens[-1], module) -def fuse_known_modules(mod_list): +def fuse_known_modules(mod_list, additional_fuser_method_mapping=None): r"""Returns a list of modules that fuses the operations specified in the input module list. @@ -41,7 +41,7 @@ def fuse_known_modules(mod_list): the fused operation. The rest of the elements are set to nn.Identity() """ types = tuple(type(m) for m in mod_list) - fuser_method = get_fuser_method(types) + fuser_method = get_fuser_method(types, additional_fuser_method_mapping) if fuser_method is None: raise NotImplementedError("Cannot fuse modules: {}".format(types)) new_mod : List[Optional[nn.Module]] = [None] * len(mod_list) @@ -64,20 +64,23 @@ def fuse_known_modules(mod_list): return new_mod -def _fuse_modules(model, modules_to_fuse, fuser_func=fuse_known_modules): - +# TODO: remove fuser_func? +def _fuse_modules(model, modules_to_fuse, fuser_func=fuse_known_modules, fuse_custom_config_dict=None): + if fuse_custom_config_dict is None: + fuse_custom_config_dict = {} + additional_fuser_method_mapping = fuse_custom_config_dict.get("additional_fuser_method_mapping", {}) mod_list = [] for item in modules_to_fuse: mod_list.append(_get_module(model, item)) # Fuse list of modules - new_mod_list = fuser_func(mod_list) + new_mod_list = fuser_func(mod_list, additional_fuser_method_mapping) # Replace original module list with fused module list for i, item in enumerate(modules_to_fuse): _set_module(model, item, new_mod_list[i]) -def fuse_modules(model, modules_to_fuse, inplace=False, fuser_func=fuse_known_modules): +def fuse_modules(model, modules_to_fuse, inplace=False, fuser_func=fuse_known_modules, fuse_custom_config_dict=None): r"""Fuses a list of modules into a single module Fuses only the following sequence of modules: @@ -101,6 +104,14 @@ def fuse_modules(model, modules_to_fuse, inplace=False, fuser_func=fuse_known_mo of the same length. For example, fuser_func([convModule, BNModule]) returns the list [ConvBNModule, nn.Identity()] Defaults to torch.quantization.fuse_known_modules + `fuse_custom_config_dict`: custom configuration for fusion: + Example: + fuse_custom_config_dict = { + # Additional fuser_method mapping + "additional_fuser_method_mapping": { + (torch.nn.Conv2d, torch.nn.BatchNorm2d): fuse_conv_bn + }, + } Returns: model with fused modules. A new copy is created if inplace=True. @@ -124,9 +135,9 @@ def fuse_modules(model, modules_to_fuse, inplace=False, fuser_func=fuse_known_mo if all(isinstance(module_element, str) for module_element in modules_to_fuse): # Handle case of modules_to_fuse being a list - _fuse_modules(model, modules_to_fuse, fuser_func) + _fuse_modules(model, modules_to_fuse, fuser_func, fuse_custom_config_dict) else: # Handle case of modules_to_fuse being a list of lists for module_list in modules_to_fuse: - _fuse_modules(model, module_list, fuser_func) + _fuse_modules(model, module_list, fuser_func, fuse_custom_config_dict) return model diff --git a/torch/quantization/quantize.py b/torch/quantization/quantize.py index 82295e72d5e5..9aa52d373ff6 100644 --- a/torch/quantization/quantize.py +++ b/torch/quantization/quantize.py @@ -212,6 +212,7 @@ def prepare(model, inplace=False, allow_list=None, if not inplace: model = copy.deepcopy(model) + # TODO: remove allow_list qconfig_propagation_list = allow_list if qconfig_propagation_list is None: qconfig_propagation_list = get_default_qconfig_propagation_list() @@ -365,6 +366,7 @@ def prepare_qat(model, mapping=None, inplace=False): torch._C._log_api_usage_once("quantization_api.quantize.prepare_qat") if mapping is None: mapping = get_default_qat_module_mappings() + if not inplace: model = copy.deepcopy(model) diff --git a/torch/quantization/quantize_fx.py b/torch/quantization/quantize_fx.py index 3f869b3714f5..37fec5b117e8 100644 --- a/torch/quantization/quantize_fx.py +++ b/torch/quantization/quantize_fx.py @@ -177,12 +177,12 @@ def prepare_fx(model, qconfig_dict, inplace=False, prepare_custom_config_dict=No # Additional fuser_method mapping "additional_fuser_method_mapping": { - (ModuleClass1, ModuleClass2): fuse_module1_module2 + (torch.nn.Conv2d, torch.nn.BatchNorm2d): fuse_conv_bn }, # Additioanl module mapping for qat "additional_qat_module_mapping": { - FloatModule: QATModule + torch.nn.intrinsic.ConvBn2d: torch.nn.qat.ConvBn2d }, # Additional fusion patterns