[quant][graphmode][fx] Add support for additional_fuse_method_mapping #46345

Closed
12 changes: 9 additions & 3 deletions torch/quantization/fuser_method_mappings.py
@@ -87,9 +87,15 @@ def fuse_conv_bn_relu(conv, bn, relu):
(nn.BatchNorm3d, nn.ReLU): nni.BNReLU3d,
}

# TODO: remove
def get_fuser_method(op_list):
def get_fuser_method(op_list, additional_fuser_method_mapping=None):
''' Get fuser method for the given list of module types,
raise an assertion error if no fuser method is found
'''
return DEFAULT_OP_LIST_TO_FUSER_METHOD.get(op_list, None)
if additional_fuser_method_mapping is None:
additional_fuser_method_mapping = {}
all_mappings = DEFAULT_OP_LIST_TO_FUSER_METHOD.copy()
for k, v in additional_fuser_method_mapping.items():
all_mappings[k] = v
fuser_method = all_mappings.get(op_list, None)
assert fuser_method is not None, "did not find fuser method for: {} ".format(op_list)
return fuser_method
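
For reference, a minimal sketch of calling the new signature directly; `fuse_linear_relu` is a hypothetical helper, not part of this PR. Because entries from `additional_fuser_method_mapping` are written into the merged dict after the defaults are copied, they take precedence over `DEFAULT_OP_LIST_TO_FUSER_METHOD` for the same key:

```python
import torch.nn as nn
import torch.nn.intrinsic as nni
from torch.quantization.fuser_method_mappings import get_fuser_method

def fuse_linear_relu(linear, relu):
    # Hypothetical custom fuser method: combine an nn.Linear + nn.ReLU pair
    # into the intrinsic fused module, purely to show the plumbing.
    return nni.LinearReLU(linear, relu)

additional_fuser_method_mapping = {
    (nn.Linear, nn.ReLU): fuse_linear_relu,
}

# The custom entry wins over the default mapping for this key.
fuser_method = get_fuser_method((nn.Linear, nn.ReLU), additional_fuser_method_mapping)
fused = fuser_method(nn.Linear(4, 4), nn.ReLU())
```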
5 changes: 4 additions & 1 deletion torch/quantization/fx/fuse.py
@@ -14,9 +14,12 @@

import copy
class Fuser:
def fuse(self, model, inplace=False):
def fuse(self, model, inplace=False, fuse_custom_config_dict=None):
if fuse_custom_config_dict is None:
fuse_custom_config_dict = {}
if not inplace:
model = copy.deepcopy(model)

input_root = model
input_graph = model.graph
self.modules = dict(input_root.named_modules())
14 changes: 10 additions & 4 deletions torch/quantization/fx/fusion_patterns.py
@@ -36,7 +36,10 @@ def __init__(self, quantizer, node):
self.conv_node = node
self.conv = quantizer.modules[self.conv_node.target]

def fuse(self, quantizer, load_arg):
def fuse(self, quantizer, load_arg, fuse_custom_config_dict=None):
if fuse_custom_config_dict is None:
fuse_custom_config_dict = {}
additional_fuser_method_mapping = fuse_custom_config_dict.get("additional_fuser_method_mapping", {})
op_list = []
if self.relu_node is not None:
# since relu can be used multiple times, we'll need to create a relu module for each match
@@ -60,7 +63,7 @@ def fuse(self, quantizer, load_arg):
op_list.reverse()
op_type_list = tuple(type(m) for m in op_list)
conv_parent_name, conv_name = _parent_name(self.conv_node.target)
fuser_method = get_fuser_method(op_type_list)
fuser_method = get_fuser_method(op_type_list, additional_fuser_method_mapping)
if fuser_method is None:
raise NotImplementedError("Cannot fuse modules: {}".format(op_type_list))
setattr(quantizer.modules[conv_parent_name], conv_name, fuser_method(*op_list))
@@ -89,7 +92,10 @@ def __init__(self, quantizer, node):
self.module_node = node
self.module = quantizer.modules[self.module_node.target]

def fuse(self, quantizer, load_arg):
def fuse(self, quantizer, load_arg, fuse_custom_config_dict=None):
if fuse_custom_config_dict is None:
fuse_custom_config_dict = {}
additional_fuser_method_mapping = fuse_custom_config_dict.get("additional_fuser_method_mapping", {})
op_list = []
# since relu can be used multiple times, we'll need to create a relu module for each match
if self.relu_node.op == 'call_module':
@@ -104,6 +110,6 @@ def fuse(self, quantizer, load_arg):
op_list.reverse()
op_type_list = tuple(type(m) for m in op_list)
module_parent_name, module_name = _parent_name(self.module_node.target)
fuser_method = get_fuser_method(op_type_list)
fuser_method = get_fuser_method(op_type_list, additional_fuser_method_mapping)
setattr(quantizer.modules[module_parent_name], module_name, fuser_method(*op_list))
return quantizer.fused_graph.node_copy(self.module_node, load_arg)
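
Both handlers invoke the resolved callable as `fuser_method(*op_list)`, with the matched module instances in forward order, and install the result in place of the first module of the pattern. A custom entry supplied through `additional_fuser_method_mapping` therefore needs the same calling convention; a hedged sketch, reusing the existing `torch.nn.utils.fusion.fuse_conv_bn_eval` helper purely as an illustrative callable:

```python
import torch.nn as nn
from torch.nn.utils.fusion import fuse_conv_bn_eval

fuse_custom_config_dict = {
    "additional_fuser_method_mapping": {
        # Called as fuser_method(conv, bn) in forward order; the return
        # value replaces the conv module on its parent.
        (nn.Conv2d, nn.BatchNorm2d): fuse_conv_bn_eval,
    }
}
```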
21 changes: 16 additions & 5 deletions torch/quantization/quantize_fx.py
@@ -26,15 +26,15 @@ def _swap_ff_with_fxff(model):
del model._modules[name]
model._modules[name] = torch.nn.quantized.FXFloatFunctional()

def _fuse_fx(graph_module, inplace=False):
def _fuse_fx(graph_module, inplace=False, fuse_custom_config_dict=None):
r""" Internal helper function to fuse modules in preparation for quantization

Args:
graph_module: GraphModule object from symbolic tracing (torch.fx.symbolic_trace)
"""
_check_is_graph_module(graph_module)
fuser = Fuser()
return fuser.fuse(graph_module, inplace)
return fuser.fuse(graph_module, inplace, fuse_custom_config_dict)

class CustomTracer(Tracer):
def __init__(self, skipped_module_names, skipped_module_classes):
@@ -79,7 +79,7 @@ def _prepare_fx(model, qconfig_dict, inplace, prepare_custom_config_dict=None, i
skipped_module_classes += custom_module_classes
tracer = CustomTracer(skipped_module_names, skipped_module_classes)
graph_module = GraphModule(model, tracer.trace(model))
graph_module = _fuse_fx(graph_module, inplace)
graph_module = _fuse_fx(graph_module, inplace, prepare_custom_config_dict)
quantizer = Quantizer()
return quantizer.prepare(
graph_module,
@@ -107,12 +107,18 @@ def _prepare_standalone_module_fx(model, qconfig_dict, inplace=False, prepare_cu
return _prepare_fx(model, qconfig_dict, inplace, prepare_custom_config_dict, is_standalone_module=True)


def fuse_fx(model, inplace=False):
def fuse_fx(model, inplace=False, fuse_custom_config_dict=None):
r""" Fuse modules like conv+bn, conv+bn+relu etc, model must be in eval mode.
Fusion rules are defined in torch.quantization.fx.fusion_pattern.py
Args:
`model`: a torch.nn.Module model
`inplace`: flag for whether we fuse modules inplace or out of place
`fuse_custom_config_dict`: Dictionary for custom configurations for fuse_fx, e.g.
fuse_custom_config_dict = {
"additional_fuser_method_mapping": {
(Module1, Module2): fuse_module1_module2
}
}

Example:
@@ -124,7 +130,7 @@ def fuse_fx(model, inplace=False):
torch._C._log_api_usage_once("quantization_api.quantize_fx.fuse_fx")
assert not model.training, 'fuse_fx only works on models in eval mode'
graph_module = torch.fx.symbolic_trace(model)
return _fuse_fx(graph_module, inplace)
return _fuse_fx(graph_module, inplace, fuse_custom_config_dict)
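
A hedged end-to-end sketch of the extended `fuse_fx` call; the toy `M` module and the choice of `fuse_conv_bn_eval` as the custom fuser method are illustrative assumptions, not part of this PR:

```python
import torch.nn as nn
from torch.nn.utils.fusion import fuse_conv_bn_eval
from torch.quantization.quantize_fx import fuse_fx

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, 1)
        self.bn = nn.BatchNorm2d(3)

    def forward(self, x):
        return self.bn(self.conv(x))

fuse_custom_config_dict = {
    "additional_fuser_method_mapping": {
        # Overrides the default conv+bn fuser method for this pair.
        (nn.Conv2d, nn.BatchNorm2d): fuse_conv_bn_eval,
    },
}

model = M().eval()  # fuse_fx asserts the model is in eval mode
fused = fuse_fx(model, fuse_custom_config_dict=fuse_custom_config_dict)
```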

def prepare_fx(model, qconfig_dict, inplace=False, prepare_custom_config_dict=None):
r""" Prepare a model for post training static quantization
@@ -186,6 +192,11 @@ def prepare_fx(model, qconfig_dict, inplace=False, prepare_custom_config_dict=No
NonTraceableModule
],

# Additional fuser_method mapping
"additional_fuser_method_mapping": {
(ModuleClass1, ModuleClass2): fuse_module1_module2
},

# Additional module mapping for qat
"additional_qat_module_mapping": {
FloatModule: QATModule
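
Similarly, a hedged sketch of threading the same mapping through `prepare_fx` via `prepare_custom_config_dict`; the toy model, the `fbgemm` qconfig, and the reuse of `fuse_conv_bn_eval` are placeholders for illustration only:

```python
import torch.nn as nn
from torch.nn.utils.fusion import fuse_conv_bn_eval
from torch.quantization import get_default_qconfig
from torch.quantization.quantize_fx import prepare_fx

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, 1)
        self.bn = nn.BatchNorm2d(3)

    def forward(self, x):
        return self.bn(self.conv(x))

qconfig_dict = {"": get_default_qconfig("fbgemm")}
prepare_custom_config_dict = {
    "additional_fuser_method_mapping": {
        # Override the default conv+bn fuser method with a custom callable.
        (nn.Conv2d, nn.BatchNorm2d): fuse_conv_bn_eval,
    },
}
prepared = prepare_fx(M().eval(), qconfig_dict,
                      prepare_custom_config_dict=prepare_custom_config_dict)
```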