Skip to content

Commit

Permalink
[quant][graphmode][fx] Add support for additional_{fusion/quant}_pattern
Browse files Browse the repository at this point in the history
Summary:
Allow user to provide additional fusion/quant patterns for fx graph mode

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
  • Loading branch information
jerryzh168 committed Oct 14, 2020
1 parent 453a249 commit 0f78cca
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 2 deletions.
3 changes: 2 additions & 1 deletion torch/quantization/fx/fuse.py
Expand Up @@ -24,7 +24,8 @@ def fuse(self, model, inplace=False, fuse_custom_config_dict=None):
input_graph = model.graph
self.modules = dict(input_root.named_modules())

fusion_patterns = get_default_fusion_patterns()
additional_fusion_patterns = fuse_custom_config_dict.get("additional_quant_pattern", {})
fusion_patterns = dict(get_default_fusion_patterns(), **additional_fusion_patterns)
# find fusion
fusion_pairs = self._find_matches(input_root, input_graph, fusion_patterns)
self.fused_graph = Graph()
Expand Down
3 changes: 2 additions & 1 deletion torch/quantization/fx/quantize.py
Expand Up @@ -325,7 +325,8 @@ def _prepare(self, model, qconfig_dict, inplace, prepare_custom_config_dict, is_
prepare_custom_config_dict = {}
if not inplace:
model = copy.deepcopy(model)
self.patterns = get_default_quant_patterns()
additional_quant_patterns = prepare_custom_config_dict.get("additional_quant_pattern", {})
self.patterns = dict(get_default_quant_patterns, **additional_quant_patterns)

flattened_qconfig_dict = get_flattened_qconfig_dict(qconfig_dict)
# TODO: support regex as well
Expand Down
11 changes: 11 additions & 0 deletions torch/quantization/quantize_fx.py
Expand Up @@ -185,6 +185,17 @@ def prepare_fx(model, qconfig_dict, inplace=False, prepare_custom_config_dict=No
"additional_qat_module_mapping": {
FloatModule: QATModule
},
# Additional fusion patterns
"additioanl_fusion_pattern": {
(torch.nn.BatchNorm2d, torch.nn.Conv2d): ConvReluFusionhandler
},
# Additional quantization patterns
"additional_quant_pattern": {
torch.nn.Conv2d: ConvReluQuantizeHandler,
(torch.nn.ReLU, torch.nn.Conv2d): ConvReluQuantizeHandler,
}
}
Expand Down

0 comments on commit 0f78cca

Please sign in to comment.