[quant][graphmode][fx] Add support for additional_fuse_method_mapping
Summary:
Allow users to add more fusion mappings through the new additional_fuser_method_mapping key in the custom config dict

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
jerryzh168 committed Oct 14, 2020
1 parent 38b3909 commit 453a249
Showing 4 changed files with 37 additions and 13 deletions.
10 changes: 7 additions & 3 deletions torch/quantization/fuser_method_mappings.py
@@ -87,9 +87,13 @@ def fuse_conv_bn_relu(conv, bn, relu):
(nn.BatchNorm3d, nn.ReLU): nni.BNReLU3d,
}

# TODO: remove
def get_fuser_method(op_list):
def get_fuser_method(op_list, additional_fuser_method_mapping=None):
''' Get fuser method for the given list of module types,
raise an AssertionError if no fuser method is found
'''
return DEFAULT_OP_LIST_TO_FUSER_METHOD.get(op_list, None)
if additional_fuser_method_mapping is None:
additional_fuser_method_mapping = {}
all_mappings = dict(DEFAULT_OP_LIST_TO_FUSER_METHOD, **additional_fuser_method_mapping)
fuser_method = all_mappings.get(op_list, None)
assert fuser_method is not None, "did not find fuser method for: {} ".format(op_list)
return fuser_method
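A minimal sketch of the new lookup path. The (nn.Conv2d, nn.GELU) pair and fuse_conv_gelu are hypothetical, chosen only because they are not in the default mapping; a real fuser method would typically return an intrinsic fused module rather than nn.Sequential:
```python
import torch.nn as nn
from torch.quantization.fuser_method_mappings import get_fuser_method

def fuse_conv_gelu(conv, gelu):
    # Hypothetical fuser method; nn.Sequential stands in for a proper
    # fused module to keep the sketch runnable.
    return nn.Sequential(conv, gelu)

additional = {(nn.Conv2d, nn.GELU): fuse_conv_gelu}
fuser_method = get_fuser_method((nn.Conv2d, nn.GELU),
                                additional_fuser_method_mapping=additional)
fused = fuser_method(nn.Conv2d(3, 3, 1), nn.GELU())
```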
5 changes: 4 additions & 1 deletion torch/quantization/fx/fuse.py
@@ -14,9 +14,12 @@

import copy
class Fuser:
def fuse(self, model, inplace=False):
def fuse(self, model, inplace=False, fuse_custom_config_dict=None):
if fuse_custom_config_dict is None:
fuse_custom_config_dict = {}
if not inplace:
model = copy.deepcopy(model)

input_root = model
input_graph = model.graph
self.modules = dict(input_root.named_modules())
14 changes: 10 additions & 4 deletions torch/quantization/fx/fusion_patterns.py
@@ -36,7 +36,10 @@ def __init__(self, quantizer, node):
self.conv_node = node
self.conv = quantizer.modules[self.conv_node.target]

def fuse(self, quantizer, load_arg):
def fuse(self, quantizer, load_arg, fuse_custom_config_dict=None):
if fuse_custom_config_dict is None:
fuse_custom_config_dict = {}
additional_fuser_method_mapping = fuse_custom_config_dict.get("additional_fuser_method_mapping", {})
op_list = []
if self.relu_node is not None:
# since relu can be used multiple times, we'll need to create a relu module for each match
@@ -60,7 +63,7 @@ def fuse(self, quantizer, load_arg):
op_list.reverse()
op_type_list = tuple(type(m) for m in op_list)
conv_parent_name, conv_name = _parent_name(self.conv_node.target)
fuser_method = get_fuser_method(op_type_list)
fuser_method = get_fuser_method(op_type_list, additional_fuser_method_mapping)
if fuser_method is None:
raise NotImplementedError("Cannot fuse modules: {}".format(op_type_list))
setattr(quantizer.modules[conv_parent_name], conv_name, fuser_method(*op_list))
@@ -89,7 +92,10 @@ def __init__(self, quantizer, node):
self.module_node = node
self.module = quantizer.modules[self.module_node.target]

def fuse(self, quantizer, load_arg):
def fuse(self, quantizer, load_arg, fuse_custom_config_dict=None):
if fuse_custom_config_dict is None:
fuse_custom_config_dict = {}
additional_fuser_method_mapping = fuse_custom_config_dict.get("additional_fuser_method_mapping", {})
op_list = []
# since relu can be used multiple times, we'll need to create a relu module for each match
if self.relu_node.op == 'call_module':
@@ -104,6 +110,6 @@ def fuse(self, quantizer, load_arg):
op_list.reverse()
op_type_list = tuple(type(m) for m in op_list)
module_parent_name, module_name = _parent_name(self.module_node.target)
fuser_method = get_fuser_method(op_type_list)
fuser_method = get_fuser_method(op_type_list, additional_fuser_method_mapping)
setattr(quantizer.modules[module_parent_name], module_name, fuser_method(*op_list))
return quantizer.fused_graph.node_copy(self.module_node, load_arg)
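Both fusion handlers above now thread additional_fuser_method_mapping into get_fuser_method. Because get_fuser_method merges via dict(DEFAULT_OP_LIST_TO_FUSER_METHOD, **additional_fuser_method_mapping), a user entry keyed on an already-matched pattern overrides the default fuser. A toy illustration of the merge semantics (placeholder strings, not real fuser methods):
```python
defaults = {("Conv2d", "ReLU"): "default_fuse_conv_relu"}
extra = {("Conv2d", "ReLU"): "custom_fuse_conv_relu"}

# dict(defaults, **extra) is how get_fuser_method combines the two maps:
merged = dict(defaults, **extra)
assert merged[("Conv2d", "ReLU")] == "custom_fuse_conv_relu"  # user entry wins
```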
21 changes: 16 additions & 5 deletions torch/quantization/quantize_fx.py
@@ -13,15 +13,15 @@ def _check_is_graph_module(model):
'Got type:' + str(type(model)) + ' Please make ' +
'sure to follow the tutorials.')

def _fuse_fx(graph_module, inplace=False):
def _fuse_fx(graph_module, inplace=False, fuse_custom_config_dict=None):
r""" Internal helper function to fuse modules in preparation for quantization
Args:
graph_module: GraphModule object from symbolic tracing (torch.fx.symbolic_trace)
"""
_check_is_graph_module(graph_module)
fuser = Fuser()
return fuser.fuse(graph_module, inplace)
return fuser.fuse(graph_module, inplace, fuse_custom_config_dict)

class CustomTracer(Tracer):
def __init__(self, skipped_module_names, skipped_module_classes):
@@ -63,7 +63,7 @@ def _prepare_fx(model, qconfig_dict, inplace, prepare_custom_config_dict=None, i
skipped_module_classes += custom_module_classes
tracer = CustomTracer(skipped_module_names, skipped_module_classes)
graph_module = GraphModule(model, tracer.trace(model))
graph_module = _fuse_fx(graph_module, inplace)
graph_module = _fuse_fx(graph_module, inplace, prepare_custom_config_dict)
quantizer = Quantizer()
return quantizer.prepare(
graph_module,
@@ -91,12 +91,18 @@ def _prepare_standalone_module_fx(model, qconfig_dict, inplace=False, prepare_cu
return _prepare_fx(model, qconfig_dict, inplace, prepare_custom_config_dict, is_standalone_module=True)


def fuse_fx(model, inplace=False):
def fuse_fx(model, inplace=False, fuse_custom_config_dict=None):
r""" Fuse modules like conv+bn, conv+bn+relu etc, model must be in eval mode.
Fusion rules are defined in torch.quantization.fx.fusion_pattern.py
Args:
`model`: a torch.nn.Module model
`inplace`: flag for whether we fuse modules inplace or out of place
`fuse_custom_config_dict`: Dictionary for custom configurations for fuse_fx, e.g.
fuse_custom_config_dict = {
"additional_fuser_method_mapping": {
(Module1, Module2): fuse_module1_module2
}
}
Example:
```python
@@ -108,7 +114,7 @@ def fuse_fx(model, inplace=False):
torch._C._log_api_usage_once("quantization_api.quantize_fx.fuse_fx")
assert not model.training, 'fuse_fx only works on models in eval mode'
graph_module = torch.fx.symbolic_trace(model)
return _fuse_fx(graph_module, inplace)
return _fuse_fx(graph_module, inplace, fuse_custom_config_dict)
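A usage sketch for the new fuse_fx argument. The pattern matcher itself is unchanged by this commit, so the custom mapping is most naturally used to override the fuser for a pattern fusion already matches, such as conv + relu; the model M and my_fuse_conv_relu below are hypothetical:
```python
import torch.nn as nn
from torch.quantization.quantize_fx import fuse_fx

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, 1)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.conv(x))

def my_fuse_conv_relu(conv, relu):
    # Hypothetical override for the default conv+relu fuser method.
    return nn.Sequential(conv, relu)

fused = fuse_fx(M().eval(), fuse_custom_config_dict={
    "additional_fuser_method_mapping": {
        (nn.Conv2d, nn.ReLU): my_fuse_conv_relu,
    },
})
```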

def prepare_fx(model, qconfig_dict, inplace=False, prepare_custom_config_dict=None):
r""" Prepare a model for post training static quantization
@@ -170,6 +176,11 @@ def prepare_fx(model, qconfig_dict, inplace=False, prepare_custom_config_dict=No
NonTraceableModule
],
# Additional fuser_method mapping
"additional_fuser_method_mapping": {
(ModuleClass1, ModuleClass2): fuse_module1_module2
},
# Additional module mapping for qat
"additional_qat_module_mapping": {
FloatModule: QATModule
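The same key also flows through prepare_fx: _prepare_fx forwards prepare_custom_config_dict to _fuse_fx, as shown earlier in this diff. A hedged sketch, reusing the hypothetical M and my_fuse_conv_relu from the fuse_fx example above:
```python
from torch.quantization import get_default_qconfig
from torch.quantization.quantize_fx import prepare_fx

qconfig_dict = {"": get_default_qconfig("fbgemm")}
prepare_custom_config_dict = {
    "additional_fuser_method_mapping": {
        (nn.Conv2d, nn.ReLU): my_fuse_conv_relu,
    },
}
prepared = prepare_fx(M().eval(), qconfig_dict,
                      prepare_custom_config_dict=prepare_custom_config_dict)
```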
