diff --git a/torch/quantization/fx/fuse.py b/torch/quantization/fx/fuse.py
index 574c35be33ce..989f8937dd7c 100644
--- a/torch/quantization/fx/fuse.py
+++ b/torch/quantization/fx/fuse.py
@@ -24,7 +24,10 @@ def fuse(self, model, inplace=False, fuse_custom_config_dict=None):
         input_graph = model.graph
         self.modules = dict(input_root.named_modules())
 
-        fusion_patterns = get_default_fusion_patterns()
+        additional_fusion_patterns = fuse_custom_config_dict.get("additional_fusion_pattern", {})
+        fusion_patterns = get_default_fusion_patterns().copy()
+        for k, v in additional_fusion_patterns.items():
+            fusion_patterns[k] = v
         # find fusion
         fusion_pairs = self._find_matches(input_root, input_graph, fusion_patterns)
         self.fused_graph = Graph()
diff --git a/torch/quantization/fx/quantize.py b/torch/quantization/fx/quantize.py
index f700fa3ece47..3eebef4ff10a 100644
--- a/torch/quantization/fx/quantize.py
+++ b/torch/quantization/fx/quantize.py
@@ -327,7 +327,10 @@ def _prepare(self, model, qconfig_dict, inplace, prepare_custom_config_dict, is_
             prepare_custom_config_dict = {}
         if not inplace:
             model = copy.deepcopy(model)
-        self.patterns = get_default_quant_patterns()
+        additional_quant_patterns = prepare_custom_config_dict.get("additional_quant_pattern", {})
+        self.patterns = get_default_quant_patterns().copy()
+        for k, v in additional_quant_patterns.items():
+            self.patterns[k] = v
 
         flattened_qconfig_dict = get_flattened_qconfig_dict(qconfig_dict)
         # TODO: support regex as well
diff --git a/torch/quantization/quantize_fx.py b/torch/quantization/quantize_fx.py
index 7563abacca1f..38aa01cb6c89 100644
--- a/torch/quantization/quantize_fx.py
+++ b/torch/quantization/quantize_fx.py
@@ -201,6 +201,17 @@ def prepare_fx(model, qconfig_dict, inplace=False, prepare_custom_config_dict=No
       "additional_qat_module_mapping": {
          FloatModule: QATModule
       },
+
+      # Additional fusion patterns
+      "additional_fusion_pattern": {
+        (torch.nn.BatchNorm2d, torch.nn.Conv2d): ConvReluFusionHandler
+      },
+
+      # Additional quantization patterns
+      "additional_quant_pattern": {
+        torch.nn.Conv2d: ConvReluQuantizeHandler,
+        (torch.nn.ReLU, torch.nn.Conv2d): ConvReluQuantizeHandler,
+      }
     }