[quant][graphmode][fx] Add support for additional_{fusion/quant}_pattern
Summary:
Allow users to provide additional fusion/quant patterns for fx graph mode quantization, as sketched below.
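
For example, a sketch of the intended usage (prepare_fx and get_default_qconfig are real torch.quantization APIs; MyModel, MyFusionHandler, and MyQuantizeHandler are hypothetical placeholders, not part of this PR):

    import torch
    from torch.quantization import get_default_qconfig
    from torch.quantization.quantize_fx import prepare_fx

    qconfig_dict = {"": get_default_qconfig("fbgemm")}
    prepare_custom_config_dict = {
        # keys added by this PR; each maps a pattern to a handler class
        "additional_fusion_pattern": {
            (torch.nn.BatchNorm2d, torch.nn.Conv2d): MyFusionHandler,  # hypothetical
        },
        "additional_quant_pattern": {
            torch.nn.Conv2d: MyQuantizeHandler,  # hypothetical
        },
    }
    model = MyModel().eval()  # hypothetical float model
    prepared = prepare_fx(model, qconfig_dict,
                          prepare_custom_config_dict=prepare_custom_config_dict)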

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 1dae0b08c426e83393da7704d7ca2b8cd72ce9d0
Pull Request resolved: #46346
jerryzh168 committed Oct 21, 2020
1 parent 8b174ed commit cf9e804
Showing 3 changed files with 19 additions and 2 deletions.
5 changes: 4 additions & 1 deletion torch/quantization/fx/fuse.py
@@ -24,7 +24,10 @@ def fuse(self, model, inplace=False, fuse_custom_config_dict=None):
         input_graph = model.graph
         self.modules = dict(input_root.named_modules())
 
-        fusion_patterns = get_default_fusion_patterns()
+        additional_fusion_patterns = fuse_custom_config_dict.get("additional_fusion_pattern", {})
+        fusion_patterns = get_default_fusion_patterns().copy()
+        for k, v in additional_fusion_patterns.items():
+            fusion_patterns[k] = v
         # find fusion
         fusion_pairs = self._find_matches(input_root, input_graph, fusion_patterns)
         self.fused_graph = Graph()
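
The merge here is a plain dict update, so a user-supplied pattern overrides the default handler registered under the same key. A minimal sketch of that semantics, with stand-in values:

    defaults = {"a": 1, "b": 2}     # stands in for get_default_fusion_patterns()
    additional = {"b": 20, "c": 3}  # stands in for the user-supplied patterns
    merged = defaults.copy()
    for k, v in additional.items():
        merged[k] = v
    assert merged == {"a": 1, "b": 20, "c": 3}  # user entries win on collisions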
5 changes: 4 additions & 1 deletion torch/quantization/fx/quantize.py
@@ -328,7 +328,10 @@ def _prepare(self, model, qconfig_dict, inplace, prepare_custom_config_dict, is_standalone_module):
             prepare_custom_config_dict = {}
         if not inplace:
             model = copy.deepcopy(model)
-        self.patterns = get_default_quant_patterns()
+        additional_quant_patterns = prepare_custom_config_dict.get("additional_quant_pattern", {})
+        self.patterns = get_default_quant_patterns().copy()
+        for k, v in additional_quant_patterns.items():
+            self.patterns[k] = v
 
         flattened_qconfig_dict = get_flattened_qconfig_dict(qconfig_dict)
         # TODO: support regex as well
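
The values registered under "additional_quant_pattern" are handler classes that the quantizer instantiates per matched node. A schematic outline, under the assumption that custom handlers mirror the QuantizeHandler interface in torch/quantization/fx/quantization_patterns.py (exact method signatures may differ):

    # Hypothetical sketch; not part of this PR.
    class MyConvQuantizeHandler:
        def __init__(self, quantizer, node):
            # `node` is the torch.fx Node at the root of the matched pattern
            self.conv_node = node

        def convert(self, quantizer, node, load_arg, debug=False):
            # emit the quantized replacement for the matched float op here
            raise NotImplementedError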
11 changes: 11 additions & 0 deletions torch/quantization/quantize_fx.py
@@ -184,6 +184,17 @@ def prepare_fx(model, qconfig_dict, inplace=False, prepare_custom_config_dict=None):
         "additional_qat_module_mapping": {
             FloatModule: QATModule
         },
+        # Additional fusion patterns
+        "additional_fusion_pattern": {
+            (torch.nn.BatchNorm2d, torch.nn.Conv2d): ConvBNFusionHandler
+        },
+        # Additional quantization patterns
+        "additional_quant_pattern": {
+            torch.nn.Conv2d: ConvReluQuantizeHandler,
+            (torch.nn.ReLU, torch.nn.Conv2d): ConvReluQuantizeHandler,
+        }
     }
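
Pattern tuples follow the fx matcher convention of reverse execution order (the last op in the chain comes first), so (torch.nn.ReLU, torch.nn.Conv2d) matches a conv followed by a relu. A quick illustration (module names are arbitrary):

    import torch.nn as nn

    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 8, 3)
            self.relu = nn.ReLU()

        def forward(self, x):
            # runs conv -> relu; the key (nn.ReLU, nn.Conv2d) is written
            # output-op first, so it matches this conv/relu pair
            return self.relu(self.conv(x))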
