From 05d4721f8461ed634203eaafb3ea72135abf04c7 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Tue, 27 Oct 2020 10:47:02 -0700 Subject: [PATCH] [quant][graphmode][fx] custom_module support static/dynamic/weight_only quant Summary: Previously we only support static quant, this PR added support for other types of quantization. Note qat is actually orthogonal to these quant types, this is referring to the convert step where we convert the observed module to a quantized module. for qat, user will provide a CustomModule -> FakeQuantizedCustomModule in prepare_custom_config_dict and FakeQuantizedCustomModule -> static/dynamic/weight_only quantized CustomModule in convert_custom_config_dict. Test Plan: Reviewers: Subscribers: Tasks: Tags: ghstack-source-id: 849441a7a8eec52c6510889df3bc1f8597a84546 Pull Request resolved: https://github.com/pytorch/pytorch/pull/46786 --- test/quantization/test_quantize_fx.py | 123 +++++++++++------- torch/quantization/__init__.py | 2 +- .../quantization/fx/quantization_patterns.py | 16 ++- torch/quantization/fx/quantize.py | 22 ++-- torch/quantization/fx/utils.py | 70 ++++++++++ torch/quantization/quant_type.py | 12 +- torch/quantization/quantize_fx.py | 23 +++- 7 files changed, 199 insertions(+), 69 deletions(-) diff --git a/test/quantization/test_quantize_fx.py b/test/quantization/test_quantize_fx.py index 4e352a80daa3..be1e5a6cb6e8 100644 --- a/test/quantization/test_quantize_fx.py +++ b/test/quantization/test_quantize_fx.py @@ -9,6 +9,7 @@ # graph mode quantization based on fx from torch.quantization import ( QuantType, + quant_type_to_str, prepare_fx, convert_fx, prepare_qat_fx, @@ -632,104 +633,126 @@ def test_custom_module_class(self): class CustomModule(torch.nn.Module): def __init__(self): super().__init__() - self.conv = torch.nn.Conv2d(1, 1, 1) + self.linear = torch.nn.Linear(3, 3) def forward(self, x): - return self.conv(x) + return self.linear(x) class ObservedCustomModule(torch.nn.Module): - def __init__(self, conv): + def __init__(self, linear): super().__init__() - self.conv = conv + self.linear = linear def forward(self, x): - return self.conv(x) + return self.linear(x) @classmethod def from_float(cls, float_module): assert hasattr(float_module, 'qconfig') - observed = cls(float_module.conv) + observed = cls(float_module.linear) observed.qconfig = float_module.qconfig return observed - class QuantizedCustomModule(torch.nn.Module): - def __init__(self, conv): + class StaticQuantCustomModule(torch.nn.Module): + def __init__(self, linear): super().__init__() - self.conv = conv + self.linear = linear def forward(self, x): - return self.conv(x) + return self.linear(x) @classmethod def from_observed(cls, observed_module): assert hasattr(observed_module, 'qconfig') assert hasattr(observed_module, 'activation_post_process') - observed_module.conv.activation_post_process = \ + observed_module.linear.activation_post_process = \ observed_module.activation_post_process - quantized = cls(nnq.Conv2d.from_float(observed_module.conv)) + quantized = cls(nnq.Linear.from_float(observed_module.linear)) return quantized - class DynamicallyQuantizedCustomModule(torch.nn.Module): - def __init__(self, conv): + class DynamicQuantCustomModule(torch.nn.Module): + def __init__(self, linear): super().__init__() - self.conv = conv + self.linear = linear def forward(self, x): - return self.conv(x) + return self.linear(x) @classmethod def from_observed(cls, observed_module): assert hasattr(observed_module, 'qconfig') - assert hasattr(observed_module, 'activation_post_process') - quantized = cls(nnqd.Conv2d.from_float(observed_module.conv)) + quantized = cls(nnqd.Linear.from_float(observed_module.linear)) return quantized class M(torch.nn.Module): def __init__(self): super().__init__() - self.conv = torch.nn.Conv2d(1, 1, 1) + self.linear = torch.nn.Linear(3, 3) self.custom = CustomModule() def forward(self, x): - x = self.conv(x) + x = self.linear(x) x = self.custom(x) return x class RefM(torch.nn.Module): def __init__(self): super().__init__() - self.conv1 = torch.nn.Conv2d(1, 1, 1) - self.conv2 = torch.nn.Conv2d(1, 1, 1) + self.linear1 = torch.nn.Linear(3, 3) + self.linear2 = torch.nn.Linear(3, 3) def forward(self, x): - x = self.conv1(x) - x = self.conv2(x) + x = self.linear1(x) + x = self.linear2(x) return x - data = torch.randn(1, 1, 1, 1) + data = torch.randn(3, 3) # instantiate M and RefM and align the parameters original_m = M().eval() original_ref_m = RefM().eval() - original_ref_m.conv1.weight = torch.nn.Parameter(original_m.conv.weight.detach()) - original_ref_m.conv1.bias = torch.nn.Parameter(original_m.conv.bias.detach()) - original_ref_m.conv2.weight = torch.nn.Parameter(original_m.custom.conv.weight.detach()) - original_ref_m.conv2.bias = torch.nn.Parameter(original_m.custom.conv.bias.detach()) + original_ref_m.linear1.weight = torch.nn.Parameter(original_m.linear.weight.detach()) + original_ref_m.linear1.bias = torch.nn.Parameter(original_m.linear.bias.detach()) + original_ref_m.linear2.weight = torch.nn.Parameter(original_m.custom.linear.weight.detach()) + original_ref_m.linear2.bias = torch.nn.Parameter(original_m.custom.linear.bias.detach()) + + test_configs = { + "static": (default_qconfig, StaticQuantCustomModule, 3), + "dynamic": (default_dynamic_qconfig, DynamicQuantCustomModule, 0) + } - # TODO: add other quant types after mixed mode support - for quant_type in [QuantType.STATIC]: - qconfig_dict = { - "": default_qconfig, - } - prepare_custom_config_dict = { - "float_to_observed_custom_module_class": { - CustomModule: ObservedCustomModule + for quant_type in [QuantType.DYNAMIC]: + key = quant_type_to_str(quant_type) + qconfig, quantized_module_class, num_observers = test_configs[key] + qconfig_dict = {"": qconfig} + if key == "static": + prepare_custom_config_dict = { + "float_to_observed_custom_module_class": { + "static": { + CustomModule: ObservedCustomModule + } + } } - } - convert_custom_config_dict = { - "observed_to_quantized_custom_module_class": { - ObservedCustomModule: QuantizedCustomModule + convert_custom_config_dict = { + "observed_to_quantized_custom_module_class": { + "static": { + ObservedCustomModule: quantized_module_class + } + } } - } + else: + prepare_custom_config_dict = { + "non_traceable_module_class": [ + CustomModule + ] + } + convert_custom_config_dict = { + "observed_to_quantized_custom_module_class": { + "dynamic": { + CustomModule: quantized_module_class + } + } + } + # check prepared model m = prepare_fx( original_m, @@ -739,7 +762,7 @@ def forward(self, x): m(data) # all activation observers are inserted in the top level module count_check = { - ns.call_module(torch.quantization.MinMaxObserver): 3 + ns.call_module(torch.quantization.MinMaxObserver): num_observers } self.checkGraphModuleNodes(m, expected_node_occurrence=count_check) @@ -747,12 +770,14 @@ def forward(self, x): m = convert_fx( m, convert_custom_config_dict=convert_custom_config_dict) - count_check = { - ns.call_function(torch.quantize_per_tensor) : 1, - ns.call_module(nnq.Conv2d) : 1, - ns.call_method('dequantize') : 1, - } - self.checkGraphModuleNodes(m, expected_node_occurrence=count_check) + if quant_type == QuantType.STATIC: + count_check = { + ns.call_function(torch.quantize_per_tensor) : 1, + ns.call_module(nnq.Linear) : 1, + ns.call_method('dequantize') : 1, + } + self.checkGraphModuleNodes(m, expected_node_occurrence=count_check) + self.assertEqual(type(m.custom), quantized_module_class) res = m(data) # quantize the reference model diff --git a/torch/quantization/__init__.py b/torch/quantization/__init__.py index ee506b6fc6a7..d6daa79fae53 100644 --- a/torch/quantization/__init__.py +++ b/torch/quantization/__init__.py @@ -28,7 +28,7 @@ def default_eval_fn(model, calib_data): # Top level API for graph mode quantization on GraphModule(torch.fx) 'fuse_fx', 'quantize_fx', # TODO: add quantize_dynamic_fx 'prepare_fx', 'prepare_dynamic_fx', 'convert_fx', - 'QuantType', # quantization type + 'QuantType', 'quant_type_to_str', # quantization type # custom module APIs 'get_default_static_quant_module_mappings', 'get_static_quant_module_class', 'get_default_dynamic_quant_module_mappings', diff --git a/torch/quantization/fx/quantization_patterns.py b/torch/quantization/fx/quantization_patterns.py index 12787e4a87db..91398d857825 100644 --- a/torch/quantization/fx/quantization_patterns.py +++ b/torch/quantization/fx/quantization_patterns.py @@ -16,6 +16,7 @@ _parent_name, quantize_node, get_per_tensor_qparams, + get_swapped_custom_module_class, activation_is_statically_quantized, weight_is_quantized, weight_dtype, @@ -176,6 +177,12 @@ def __init__(self, quantizer, node): def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_dict=None): # TODO: debug option for conv module + qconfig = quantizer.qconfig_map[node.name] + activation_statically_quantized = activation_is_statically_quantized(qconfig) + # only static qunatization (for both ptq and qat) is supported for conv + if not activation_statically_quantized: + return quantizer.quantized_graph.node_copy(node, load_arg(quantized=None)) + if self.conv_node.op == 'call_module': # note that relu should already be fused into conv module in the fusion step assert self.relu_node is None, 'conv module and relu fusion is not executed, ' \ @@ -587,13 +594,14 @@ def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_ assert convert_custom_config_dict is not None custom_module_class_mapping = convert_custom_config_dict.get("observed_to_quantized_custom_module_class", None) assert custom_module_class_mapping is not None + qconfig = quantizer.qconfig_map[node.name] observed_custom_module = quantizer.modules[node.target] - if node.name in quantizer.activation_post_process_map: + if activation_is_statically_quantized(qconfig): + assert node.name in quantizer.activation_post_process_map observed_custom_module.activation_post_process = \ quantizer.activation_post_process_map[node.name] - quantized_custom_module_class = custom_module_class_mapping.get(type(observed_custom_module), None) - assert quantized_custom_module_class is not None, "did not found quantized custom module for:" + \ - str(type(observed_custom_module)) + quantized_custom_module_class = get_swapped_custom_module_class( + observed_custom_module, custom_module_class_mapping, qconfig) quantized_custom_module = \ quantized_custom_module_class.from_observed(observed_custom_module) parent_name, name = _parent_name(node.target) diff --git a/torch/quantization/fx/quantize.py b/torch/quantization/fx/quantize.py index 3eebef4ff10a..cae47ced07f7 100644 --- a/torch/quantization/fx/quantize.py +++ b/torch/quantization/fx/quantize.py @@ -36,6 +36,8 @@ from .utils import ( _parent_name, quantize_node, + get_custom_module_class_keys, + get_swapped_custom_module_class, activation_is_statically_quantized, ) @@ -347,9 +349,9 @@ def _prepare(self, model, qconfig_dict, inplace, prepare_custom_config_dict, is_ # match the patterns that will get quantized standalone_module_names = prepare_custom_config_dict.get("standalone_module_name", None) - custom_module_class_mapping = prepare_custom_config_dict.get("float_to_observed_custom_module_class", None) + custom_module_classes = get_custom_module_class_keys(prepare_custom_config_dict, "float_to_observed_custom_module_class") matches = self._find_matches( - model.graph, self.modules, self.patterns, standalone_module_names, custom_module_class_mapping) + model.graph, self.modules, self.patterns, standalone_module_names, custom_module_classes) # find _inputs_ to matched nodes that are not quantized, these # have to be quantized, which requires measuring stats, @@ -403,8 +405,9 @@ def insert_observer(node, observer, device): if isinstance(obj, CustomModuleQuantizeHandler): custom_module = self.modules[node.target] + custom_module_class_mapping = prepare_custom_config_dict.get("float_to_observed_custom_module_class", {}) observed_custom_module_class = \ - custom_module_class_mapping[type(custom_module)] + get_swapped_custom_module_class(custom_module, custom_module_class_mapping, qconfig) observed_custom_module = \ observed_custom_module_class.from_float(custom_module) parent_name, name = _parent_name(node.target) @@ -569,10 +572,11 @@ def _convert(self, model, inplace=False, debug=False, convert_custom_config_dict model.eval().cpu() self.modules = dict(model.named_modules()) - custom_module_class_mapping = convert_custom_config_dict.get("observed_to_quantized_custom_module_class", None) + custom_module_classes = get_custom_module_class_keys( + convert_custom_config_dict, "observed_to_quantized_custom_module_class") matches = self._find_matches( model.graph, self.modules, self.patterns, - custom_module_class_mapping=custom_module_class_mapping) + custom_module_classes=custom_module_classes) quants = self._find_quants(model.graph, matches) @@ -818,7 +822,7 @@ def convert(self, model, inplace=False, debug=False, convert_custom_config_dict= def _find_matches( self, graph, modules, patterns, - standalone_module_names=None, custom_module_class_mapping=None): + standalone_module_names=None, custom_module_classes=None): """ Matches the nodes in the input graph to quantization patterns, and outputs the information needed to quantize them in future steps. @@ -839,8 +843,8 @@ def _find_matches( ... } """ - if custom_module_class_mapping is None: - custom_module_class_mapping = {} + if custom_module_classes is None: + custom_module_classes = [] match_map = {} all_matched = set() @@ -870,7 +874,7 @@ def record_match(pattern, node, matched): # add custom module instances to the match result for node in graph.nodes: if node.op == 'call_module' and \ - type(self.modules[node.target]) in custom_module_class_mapping: + type(self.modules[node.target]) in custom_module_classes: custom_module_qconfig = self.qconfig_map[node.name] match_map[node.name] = ( node, [node], CustomModuleQuantizeHandler(self, node), custom_module_qconfig) diff --git a/torch/quantization/fx/utils.py b/torch/quantization/fx/utils.py index 98f94a0633a0..366970cec4c0 100644 --- a/torch/quantization/fx/utils.py +++ b/torch/quantization/fx/utils.py @@ -1,5 +1,6 @@ import re import torch +from ..quant_type import QuantType, quant_type_to_str # turn foo.bar -> ['foo', 'bar'] def _parent_name(target): @@ -139,6 +140,55 @@ def get_next_qparams_idx(module, qparams): inputs.append(graph.create_node('get_attr', qparam_full_path)) return graph.create_node('call_function', quantize_op, tuple(inputs), {}) +def get_custom_module_class_keys(custom_config_dict, custom_config_dict_key): + r""" Get all the unique custom module keys in the custom config dict + e.g. + Input: + custom_config_dict = { + "float_to_observed_custom_module_class": { + "static": { + CustomModule1: ObservedCustomModule + }, + "dynamic": { + CustomModule2: DynamicObservedCustomModule + }, + "weight_only": { + CustomModule3: WeightOnlyObservedCustomModule + }, + }, + } + + Output: + # extract all the keys in "static", "dynamic" and "weight_only" dict + [CustomModule1, CustomModule2, CustomModule3] + """ + # using set to dedup + float_custom_module_classes = set() + custom_module_mapping = custom_config_dict.get(custom_config_dict_key, {}) + for quant_mode in ["static", "dynamic", "weight_only"]: + quant_mode_custom_module_config = custom_module_mapping.get(quant_mode, {}) + quant_mode_custom_module_classes = set(quant_mode_custom_module_config.keys()) + float_custom_module_classes |= quant_mode_custom_module_classes + return list(float_custom_module_classes) + +def get_swapped_custom_module_class(custom_module, custom_module_class_mapping, qconfig): + """ Get the observed/quantized custom module class that we need + to swap `custom_module` to + Input: + custom_module: input, can be an instance of either a float or observed custom module + custom_module_class_mapping: the float to observed or observed to quantized custom module class mapping + qconfig: qconfig configured for the custom module + + Output: + corresponding observed/quantized custom module class for input custom module instance + """ + quant_type = get_quant_type(qconfig) + quant_type_str = quant_type_to_str(quant_type) + class_mapping = custom_module_class_mapping.get(quant_type_str, {}) + assert type(custom_module) in class_mapping, "did not found corresponding observed " \ + "module class for {} in mapping: {}".format(type(custom_module), class_mapping) + return class_mapping[type(custom_module)] + def activation_is_statically_quantized(qconfig): """ Given a qconfig, decide if the activation needs to be statically quantized or not @@ -158,6 +208,26 @@ def weight_is_quantized(qconfig): """ return weight_dtype(qconfig) in [torch.quint8, torch.qint8] +def get_quant_type(qconfig): + assert qconfig is not None + activation = qconfig.activation() + weight = qconfig.weight() + static_dtypes = [torch.quint8, torch.qint8] + if weight.dtype in static_dtypes: + if activation.dtype in static_dtypes: + return QuantType.STATIC + elif hasattr(activation, 'compute_dtype') and activation.compute_dtype in static_dtypes: + return QuantType.DYNAMIC + else: + return QuantType.WEIGHT_ONLY + + if weight.dtype == torch.float16: + if activation.dtype == torch.float: + return QuantType.WEIGHT_ONLY + + raise Exception("Unrecognized dtype combination in get_quant_type: activation({})," + "weight({})".format(activation.dtype, weight.dtype)) + def get_linear_prepack_op_for_dtype(dtype): if dtype == torch.float16: return torch.ops.quantized.linear_prepack_fp16 diff --git a/torch/quantization/quant_type.py b/torch/quantization/quant_type.py index 212dec1fe28c..463d086b39b6 100644 --- a/torch/quantization/quant_type.py +++ b/torch/quantization/quant_type.py @@ -1,4 +1,3 @@ - import enum # Quantization type (dynamic quantization, static quantization). @@ -7,3 +6,14 @@ class QuantType(enum.IntEnum): DYNAMIC = 0 STATIC = 1 QAT = 2 + WEIGHT_ONLY = 3 + + +def quant_type_to_str(quant_type): + m = { + QuantType.STATIC: "static", + QuantType.DYNAMIC: "dynamic", + QuantType.QAT: "qat", + QuantType.WEIGHT_ONLY: "weight_only", + } + return m[quant_type] diff --git a/torch/quantization/quantize_fx.py b/torch/quantization/quantize_fx.py index 001c797ad6b1..e3a116be1785 100644 --- a/torch/quantization/quantize_fx.py +++ b/torch/quantization/quantize_fx.py @@ -4,6 +4,7 @@ from .fx import Fuser # noqa: F401 from .fx import Quantizer # noqa: F401 from .fx.utils import graph_pretty_str # noqa: F401 +from .fx.utils import get_custom_module_class_keys # noqa: F401 def _check_is_graph_module(model): if not isinstance(model, GraphModule): @@ -74,9 +75,9 @@ def _prepare_fx(model, qconfig_dict, inplace, prepare_custom_config_dict=None, i # standalone module and custom module config are applied in top level module standalone_module_names = prepare_custom_config_dict.get('standalone_module_name', []) skipped_module_names += standalone_module_names - custom_module_config = prepare_custom_config_dict.get('float_to_observed_custom_module_class', {}) - custom_module_classes = list(custom_module_config.keys()) - skipped_module_classes += custom_module_classes + float_custom_module_classes = get_custom_module_class_keys( + prepare_custom_config_dict, "float_to_observed_custom_module_class") + skipped_module_classes += float_custom_module_classes tracer = CustomTracer(skipped_module_names, skipped_module_classes) graph_module = GraphModule(model, tracer.trace(model)) graph_module = _fuse_fx(graph_module, inplace, prepare_custom_config_dict) @@ -178,8 +179,11 @@ def prepare_fx(model, qconfig_dict, inplace=False, prepare_custom_config_dict=No # user will manually define the corresponding observed # module class which has a from_float class method that converts # float custom module to observed custom module + # (only needed for static quantization) "float_to_observed_custom_module_class": { - CustomModule: ObservedCustomModule + "static": { + CustomModule: ObservedCustomModule + } }, # the qualified names for the submodule that are not symbolically traceable @@ -188,6 +192,7 @@ def prepare_fx(model, qconfig_dict, inplace=False, prepare_custom_config_dict=No ], # the module classes that are not symbolically traceable + # we'll also put dynamic/weight_only custom module here "non_traceable_module_class": [ NonTraceableModule ], @@ -313,7 +318,15 @@ def convert_fx(graph_module, inplace=False, debug=False, convert_custom_config_d # module class which has a from_observed class method that converts # observed custom module to quantized custom module "observed_to_quantized_custom_module_class": { - ObservedCustomModule: QuantizedCustomModule + "static": { + ObservedCustomModule: QuantizedCustomModule + }, + "dynamic": { + ObservedCustomModule: QuantizedCustomModule + }, + "weight_only": { + ObservedCustomModule: QuantizedCustomModule + } } }