diff --git a/test/quantization/test_quantize_fx.py b/test/quantization/test_quantize_fx.py
index 4e352a80daa3..be1e5a6cb6e8 100644
--- a/test/quantization/test_quantize_fx.py
+++ b/test/quantization/test_quantize_fx.py
@@ -9,6 +9,7 @@
 # graph mode quantization based on fx
 from torch.quantization import (
     QuantType,
+    quant_type_to_str,
     prepare_fx,
     convert_fx,
     prepare_qat_fx,
@@ -632,104 +633,126 @@ def test_custom_module_class(self):
         class CustomModule(torch.nn.Module):
             def __init__(self):
                 super().__init__()
-                self.conv = torch.nn.Conv2d(1, 1, 1)
+                self.linear = torch.nn.Linear(3, 3)

             def forward(self, x):
-                return self.conv(x)
+                return self.linear(x)

         class ObservedCustomModule(torch.nn.Module):
-            def __init__(self, conv):
+            def __init__(self, linear):
                 super().__init__()
-                self.conv = conv
+                self.linear = linear

             def forward(self, x):
-                return self.conv(x)
+                return self.linear(x)

             @classmethod
             def from_float(cls, float_module):
                 assert hasattr(float_module, 'qconfig')
-                observed = cls(float_module.conv)
+                observed = cls(float_module.linear)
                 observed.qconfig = float_module.qconfig
                 return observed

-        class QuantizedCustomModule(torch.nn.Module):
-            def __init__(self, conv):
+        class StaticQuantCustomModule(torch.nn.Module):
+            def __init__(self, linear):
                 super().__init__()
-                self.conv = conv
+                self.linear = linear

             def forward(self, x):
-                return self.conv(x)
+                return self.linear(x)

             @classmethod
             def from_observed(cls, observed_module):
                 assert hasattr(observed_module, 'qconfig')
                 assert hasattr(observed_module, 'activation_post_process')
-                observed_module.conv.activation_post_process = \
+                observed_module.linear.activation_post_process = \
                     observed_module.activation_post_process
-                quantized = cls(nnq.Conv2d.from_float(observed_module.conv))
+                quantized = cls(nnq.Linear.from_float(observed_module.linear))
                 return quantized

-        class DynamicallyQuantizedCustomModule(torch.nn.Module):
-            def __init__(self, conv):
+        class DynamicQuantCustomModule(torch.nn.Module):
+            def __init__(self, linear):
                 super().__init__()
-                self.conv = conv
+                self.linear = linear

             def forward(self, x):
-                return self.conv(x)
+                return self.linear(x)

             @classmethod
             def from_observed(cls, observed_module):
                 assert hasattr(observed_module, 'qconfig')
-                assert hasattr(observed_module, 'activation_post_process')
-                quantized = cls(nnqd.Conv2d.from_float(observed_module.conv))
+                quantized = cls(nnqd.Linear.from_float(observed_module.linear))
                 return quantized

         class M(torch.nn.Module):
             def __init__(self):
                 super().__init__()
-                self.conv = torch.nn.Conv2d(1, 1, 1)
+                self.linear = torch.nn.Linear(3, 3)
                 self.custom = CustomModule()

             def forward(self, x):
-                x = self.conv(x)
+                x = self.linear(x)
                 x = self.custom(x)
                 return x

         class RefM(torch.nn.Module):
             def __init__(self):
                 super().__init__()
-                self.conv1 = torch.nn.Conv2d(1, 1, 1)
-                self.conv2 = torch.nn.Conv2d(1, 1, 1)
+                self.linear1 = torch.nn.Linear(3, 3)
+                self.linear2 = torch.nn.Linear(3, 3)

             def forward(self, x):
-                x = self.conv1(x)
-                x = self.conv2(x)
+                x = self.linear1(x)
+                x = self.linear2(x)
                 return x

-        data = torch.randn(1, 1, 1, 1)
+        data = torch.randn(3, 3)

         # instantiate M and RefM and align the parameters
         original_m = M().eval()
         original_ref_m = RefM().eval()
-        original_ref_m.conv1.weight = torch.nn.Parameter(original_m.conv.weight.detach())
-        original_ref_m.conv1.bias = torch.nn.Parameter(original_m.conv.bias.detach())
-        original_ref_m.conv2.weight = torch.nn.Parameter(original_m.custom.conv.weight.detach())
-        original_ref_m.conv2.bias = torch.nn.Parameter(original_m.custom.conv.bias.detach())
+        original_ref_m.linear1.weight = torch.nn.Parameter(original_m.linear.weight.detach())
+        original_ref_m.linear1.bias = torch.nn.Parameter(original_m.linear.bias.detach())
+        original_ref_m.linear2.weight = torch.nn.Parameter(original_m.custom.linear.weight.detach())
+        original_ref_m.linear2.bias = torch.nn.Parameter(original_m.custom.linear.bias.detach())
+
+        test_configs = {
+            "static": (default_qconfig, StaticQuantCustomModule, 3),
+            "dynamic": (default_dynamic_qconfig, DynamicQuantCustomModule, 0)
+        }

-        # TODO: add other quant types after mixed mode support
-        for quant_type in [QuantType.STATIC]:
-            qconfig_dict = {
-                "": default_qconfig,
-            }
-            prepare_custom_config_dict = {
-                "float_to_observed_custom_module_class": {
-                    CustomModule: ObservedCustomModule
+        for quant_type in [QuantType.STATIC, QuantType.DYNAMIC]:
+            key = quant_type_to_str(quant_type)
+            qconfig, quantized_module_class, num_observers = test_configs[key]
+            qconfig_dict = {"": qconfig}
+            if key == "static":
+                prepare_custom_config_dict = {
+                    "float_to_observed_custom_module_class": {
+                        "static": {
+                            CustomModule: ObservedCustomModule
+                        }
+                    }
                 }
-            }
-            convert_custom_config_dict = {
-                "observed_to_quantized_custom_module_class": {
-                    ObservedCustomModule: QuantizedCustomModule
+                convert_custom_config_dict = {
+                    "observed_to_quantized_custom_module_class": {
+                        "static": {
+                            ObservedCustomModule: quantized_module_class
+                        }
+                    }
                 }
-            }
+            else:
+                prepare_custom_config_dict = {
+                    "non_traceable_module_class": [
+                        CustomModule
+                    ]
+                }
+                convert_custom_config_dict = {
+                    "observed_to_quantized_custom_module_class": {
+                        "dynamic": {
+                            CustomModule: quantized_module_class
+                        }
+                    }
+                }
+
             # check prepared model
             m = prepare_fx(
                 original_m,
@@ -739,7 +762,7 @@ def forward(self, x):
             m(data)
             # all activation observers are inserted in the top level module
             count_check = {
-                ns.call_module(torch.quantization.MinMaxObserver): 3
+                ns.call_module(torch.quantization.MinMaxObserver): num_observers
             }
             self.checkGraphModuleNodes(m, expected_node_occurrence=count_check)

@@ -747,12 +770,14 @@ def forward(self, x):
             m = convert_fx(
                 m,
                 convert_custom_config_dict=convert_custom_config_dict)
-            count_check = {
-                ns.call_function(torch.quantize_per_tensor) : 1,
-                ns.call_module(nnq.Conv2d) : 1,
-                ns.call_method('dequantize') : 1,
-            }
-            self.checkGraphModuleNodes(m, expected_node_occurrence=count_check)
+            if quant_type == QuantType.STATIC:
+                count_check = {
+                    ns.call_function(torch.quantize_per_tensor) : 1,
+                    ns.call_module(nnq.Linear) : 1,
+                    ns.call_method('dequantize') : 1,
+                }
+                self.checkGraphModuleNodes(m, expected_node_occurrence=count_check)
+            self.assertEqual(type(m.custom), quantized_module_class)
             res = m(data)

             # quantize the reference model
diff --git a/torch/quantization/__init__.py b/torch/quantization/__init__.py
index ee506b6fc6a7..d6daa79fae53 100644
--- a/torch/quantization/__init__.py
+++ b/torch/quantization/__init__.py
@@ -28,7 +28,7 @@ def default_eval_fn(model, calib_data):
     # Top level API for graph mode quantization on GraphModule(torch.fx)
     'fuse_fx', 'quantize_fx',  # TODO: add quantize_dynamic_fx
     'prepare_fx', 'prepare_dynamic_fx', 'convert_fx',
-    'QuantType',  # quantization type
+    'QuantType', 'quant_type_to_str',  # quantization type
     # custom module APIs
     'get_default_static_quant_module_mappings', 'get_static_quant_module_class',
     'get_default_dynamic_quant_module_mappings',
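
Editor's note: to make the test's new "dynamic" branch concrete outside the test harness, here is a minimal, self-contained sketch of dynamic custom module quantization with the quant-type-keyed config dicts introduced above. All class names (MyModel, MyCustomModule, DynamicQuantMyCustomModule) are invented for this illustration; the snippet mirrors the test rather than documenting an official API.

import torch
import torch.nn.quantized.dynamic as nnqd
from torch.quantization import default_dynamic_qconfig, prepare_fx, convert_fx

class MyCustomModule(torch.nn.Module):  # placeholder float custom module
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 3)

    def forward(self, x):
        return self.linear(x)

class DynamicQuantMyCustomModule(torch.nn.Module):  # placeholder quantized version
    def __init__(self, linear):
        super().__init__()
        self.linear = linear

    def forward(self, x):
        return self.linear(x)

    @classmethod
    def from_observed(cls, observed_module):
        # dynamic quantization needs no activation observer on the custom module
        return cls(nnqd.Linear.from_float(observed_module.linear))

class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.custom = MyCustomModule()

    def forward(self, x):
        return self.custom(x)

# the dynamic custom module is kept untraced at prepare time ...
prepare_custom_config_dict = {"non_traceable_module_class": [MyCustomModule]}
# ... and swapped through the "dynamic" entry at convert time
convert_custom_config_dict = {
    "observed_to_quantized_custom_module_class": {
        "dynamic": {MyCustomModule: DynamicQuantMyCustomModule}
    }
}
m = prepare_fx(MyModel().eval(), {"": default_dynamic_qconfig},
               prepare_custom_config_dict=prepare_custom_config_dict)
m = convert_fx(m, convert_custom_config_dict=convert_custom_config_dict)
print(type(m.custom))  # expected: DynamicQuantMyCustomModule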
diff --git a/torch/quantization/fx/quantization_patterns.py b/torch/quantization/fx/quantization_patterns.py
index 12787e4a87db..91398d857825 100644
--- a/torch/quantization/fx/quantization_patterns.py
+++ b/torch/quantization/fx/quantization_patterns.py
@@ -16,6 +16,7 @@
     _parent_name,
     quantize_node,
     get_per_tensor_qparams,
+    get_swapped_custom_module_class,
     activation_is_statically_quantized,
     weight_is_quantized,
     weight_dtype,
@@ -176,6 +177,12 @@ def __init__(self, quantizer, node):

     def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_dict=None):
         # TODO: debug option for conv module
+        qconfig = quantizer.qconfig_map[node.name]
+        activation_statically_quantized = activation_is_statically_quantized(qconfig)
+        # only static quantization (for both ptq and qat) is supported for conv
+        if not activation_statically_quantized:
+            return quantizer.quantized_graph.node_copy(node, load_arg(quantized=None))
+
         if self.conv_node.op == 'call_module':
             # note that relu should already be fused into conv module in the fusion step
             assert self.relu_node is None, 'conv module and relu fusion is not executed, ' \
@@ -587,13 +594,14 @@ def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_
         assert convert_custom_config_dict is not None
         custom_module_class_mapping = convert_custom_config_dict.get("observed_to_quantized_custom_module_class", None)
         assert custom_module_class_mapping is not None
+        qconfig = quantizer.qconfig_map[node.name]
         observed_custom_module = quantizer.modules[node.target]
-        if node.name in quantizer.activation_post_process_map:
+        if activation_is_statically_quantized(qconfig):
+            assert node.name in quantizer.activation_post_process_map
             observed_custom_module.activation_post_process = \
                 quantizer.activation_post_process_map[node.name]
-        quantized_custom_module_class = custom_module_class_mapping.get(type(observed_custom_module), None)
-        assert quantized_custom_module_class is not None, "did not found quantized custom module for:" + \
-            str(type(observed_custom_module))
+        quantized_custom_module_class = get_swapped_custom_module_class(
+            observed_custom_module, custom_module_class_mapping, qconfig)
         quantized_custom_module = \
             quantized_custom_module_class.from_observed(observed_custom_module)
         parent_name, name = _parent_name(node.target)
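
Editor's note: the new early return in the conv handler means a conv under a dynamic-only qconfig should now fall through as a regular float call instead of reaching the static conversion path. A minimal sketch of that expected behavior, written for this edit (not taken from the PR's test suite); the expected output assumes the guard above.

import torch
from torch.quantization import default_dynamic_qconfig, prepare_fx, convert_fx

class ConvModel(torch.nn.Module):  # throwaway model for illustration
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(1, 1, 1)

    def forward(self, x):
        return self.conv(x)

# with a dynamic qconfig the conv pattern is copied through unquantized,
# so the module should stay a float torch.nn.Conv2d after convert_fx
m = prepare_fx(ConvModel().eval(), {"": default_dynamic_qconfig})
m = convert_fx(m)
print(type(m.conv))  # expected: <class 'torch.nn.Conv2d'>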
diff --git a/torch/quantization/fx/quantize.py b/torch/quantization/fx/quantize.py
index 3eebef4ff10a..cae47ced07f7 100644
--- a/torch/quantization/fx/quantize.py
+++ b/torch/quantization/fx/quantize.py
@@ -36,6 +36,8 @@
 from .utils import (
     _parent_name,
     quantize_node,
+    get_custom_module_class_keys,
+    get_swapped_custom_module_class,
     activation_is_statically_quantized,
 )

@@ -347,9 +349,9 @@ def _prepare(self, model, qconfig_dict, inplace, prepare_custom_config_dict, is_
         # match the patterns that will get quantized
         standalone_module_names = prepare_custom_config_dict.get("standalone_module_name", None)
-        custom_module_class_mapping = prepare_custom_config_dict.get("float_to_observed_custom_module_class", None)
+        custom_module_classes = get_custom_module_class_keys(prepare_custom_config_dict, "float_to_observed_custom_module_class")
         matches = self._find_matches(
-            model.graph, self.modules, self.patterns, standalone_module_names, custom_module_class_mapping)
+            model.graph, self.modules, self.patterns, standalone_module_names, custom_module_classes)

         # find _inputs_ to matched nodes that are not quantized, these
         # have to be quantized, which requires measuring stats,
@@ -403,8 +405,9 @@ def insert_observer(node, observer, device):

                 if isinstance(obj, CustomModuleQuantizeHandler):
                     custom_module = self.modules[node.target]
+                    custom_module_class_mapping = prepare_custom_config_dict.get("float_to_observed_custom_module_class", {})
                     observed_custom_module_class = \
-                        custom_module_class_mapping[type(custom_module)]
+                        get_swapped_custom_module_class(custom_module, custom_module_class_mapping, qconfig)
                     observed_custom_module = \
                         observed_custom_module_class.from_float(custom_module)
                     parent_name, name = _parent_name(node.target)
@@ -569,10 +572,11 @@ def _convert(self, model, inplace=False, debug=False, convert_custom_config_dict
             model.eval().cpu()
         self.modules = dict(model.named_modules())

-        custom_module_class_mapping = convert_custom_config_dict.get("observed_to_quantized_custom_module_class", None)
+        custom_module_classes = get_custom_module_class_keys(
+            convert_custom_config_dict, "observed_to_quantized_custom_module_class")
         matches = self._find_matches(
             model.graph, self.modules, self.patterns,
-            custom_module_class_mapping=custom_module_class_mapping)
+            custom_module_classes=custom_module_classes)

         quants = self._find_quants(model.graph, matches)

@@ -818,7 +822,7 @@ def convert(self, model, inplace=False, debug=False, convert_custom_config_dict=
     def _find_matches(
             self, graph, modules, patterns,
-            standalone_module_names=None, custom_module_class_mapping=None):
+            standalone_module_names=None, custom_module_classes=None):
         """
         Matches the nodes in the input graph to quantization patterns, and
         outputs the information needed to quantize them in future steps.
@@ -839,8 +843,8 @@ def _find_matches(
             ...
         }
         """
-        if custom_module_class_mapping is None:
-            custom_module_class_mapping = {}
+        if custom_module_classes is None:
+            custom_module_classes = []

         match_map = {}
         all_matched = set()
@@ -870,7 +874,7 @@ def record_match(pattern, node, matched):
         # add custom module instances to the match result
         for node in graph.nodes:
             if node.op == 'call_module' and \
-               type(self.modules[node.target]) in custom_module_class_mapping:
+               type(self.modules[node.target]) in custom_module_classes:
                 custom_module_qconfig = self.qconfig_map[node.name]
                 match_map[node.name] = (
                     node, [node], CustomModuleQuantizeHandler(self, node), custom_module_qconfig)
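
Editor's note: both the prepare path and the custom module convert path now resolve the target class through get_swapped_custom_module_class (defined in the utils diff below). A small illustrative sketch of that lookup, using throwaway classes invented for this note and the import path implied by this diff.

import torch
from torch.quantization import default_qconfig
from torch.quantization.fx.utils import get_swapped_custom_module_class

class FloatCustom(torch.nn.Module):  # placeholder float custom module
    pass

class ObservedCustom(torch.nn.Module):  # placeholder observed custom module
    pass

# default_qconfig classifies as "static", so the lookup goes through the
# "static" sub-dict of the quant-type-keyed mapping
mapping = {"static": {FloatCustom: ObservedCustom}}
swapped = get_swapped_custom_module_class(FloatCustom(), mapping, default_qconfig)
assert swapped is ObservedCustom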
diff --git a/torch/quantization/fx/utils.py b/torch/quantization/fx/utils.py
index 98f94a0633a0..366970cec4c0 100644
--- a/torch/quantization/fx/utils.py
+++ b/torch/quantization/fx/utils.py
@@ -1,5 +1,6 @@
 import re
 import torch
+from ..quant_type import QuantType, quant_type_to_str

 # turn foo.bar -> ['foo', 'bar']
 def _parent_name(target):
@@ -139,6 +140,55 @@ def get_next_qparams_idx(module, qparams):
         inputs.append(graph.create_node('get_attr', qparam_full_path))
     return graph.create_node('call_function', quantize_op, tuple(inputs), {})

+def get_custom_module_class_keys(custom_config_dict, custom_config_dict_key):
+    r""" Get all the unique custom module keys in the custom config dict
+    e.g.
+    Input:
+        custom_config_dict = {
+            "float_to_observed_custom_module_class": {
+                "static": {
+                    CustomModule1: ObservedCustomModule
+                },
+                "dynamic": {
+                    CustomModule2: DynamicObservedCustomModule
+                },
+                "weight_only": {
+                    CustomModule3: WeightOnlyObservedCustomModule
+                },
+            },
+        }
+
+    Output:
+        # extract all the keys in the "static", "dynamic" and "weight_only" dicts
+        [CustomModule1, CustomModule2, CustomModule3]
+    """
+    # using a set to dedup
+    float_custom_module_classes = set()
+    custom_module_mapping = custom_config_dict.get(custom_config_dict_key, {})
+    for quant_mode in ["static", "dynamic", "weight_only"]:
+        quant_mode_custom_module_config = custom_module_mapping.get(quant_mode, {})
+        quant_mode_custom_module_classes = set(quant_mode_custom_module_config.keys())
+        float_custom_module_classes |= quant_mode_custom_module_classes
+    return list(float_custom_module_classes)
+
+def get_swapped_custom_module_class(custom_module, custom_module_class_mapping, qconfig):
+    """ Get the observed/quantized custom module class that we need
+    to swap `custom_module` to
+    Input:
+        custom_module: input, can be an instance of either a float or observed custom module
+        custom_module_class_mapping: the float to observed or observed to quantized custom module class mapping
+        qconfig: qconfig configured for the custom module
+
+    Output:
+        the corresponding observed/quantized custom module class for the input custom module instance
+    """
+    quant_type = get_quant_type(qconfig)
+    quant_type_str = quant_type_to_str(quant_type)
+    class_mapping = custom_module_class_mapping.get(quant_type_str, {})
+    assert type(custom_module) in class_mapping, "did not find the corresponding observed " \
+        "module class for {} in mapping: {}".format(type(custom_module), class_mapping)
+    return class_mapping[type(custom_module)]
+
 def activation_is_statically_quantized(qconfig):
     """ Given a qconfig, decide if the activation needs to be
     statically quantized or not
@@ -158,6 +208,26 @@ def weight_is_quantized(qconfig):
     """
     return weight_dtype(qconfig) in [torch.quint8, torch.qint8]

+def get_quant_type(qconfig):
+    assert qconfig is not None
+    activation = qconfig.activation()
+    weight = qconfig.weight()
+    static_dtypes = [torch.quint8, torch.qint8]
+    if weight.dtype in static_dtypes:
+        if activation.dtype in static_dtypes:
+            return QuantType.STATIC
+        elif hasattr(activation, 'compute_dtype') and activation.compute_dtype in static_dtypes:
+            return QuantType.DYNAMIC
+        else:
+            return QuantType.WEIGHT_ONLY
+
+    if weight.dtype == torch.float16:
+        if activation.dtype == torch.float:
+            return QuantType.WEIGHT_ONLY
+
+    raise Exception("Unrecognized dtype combination in get_quant_type: activation({}),"
+                    "weight({})".format(activation.dtype, weight.dtype))
+
 def get_linear_prepack_op_for_dtype(dtype):
     if dtype == torch.float16:
         return torch.ops.quantized.linear_prepack_fp16
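
Editor's note: a quick, hedged illustration of the helpers added above. get_quant_type and get_custom_module_class_keys are internal fx utilities, so the import paths simply follow this diff; the two Custom* classes are placeholders defined only for the example.

import torch
from torch.quantization import default_qconfig, default_dynamic_qconfig, quant_type_to_str
from torch.quantization.fx.utils import get_quant_type, get_custom_module_class_keys

# classify qconfigs by their activation/weight observer dtypes
print(quant_type_to_str(get_quant_type(default_qconfig)))          # expected: static
print(quant_type_to_str(get_quant_type(default_dynamic_qconfig)))  # expected: dynamic

class CustomA(torch.nn.Module):  # placeholder custom module
    pass

class CustomB(torch.nn.Module):  # placeholder custom module
    pass

# only the keys of the per-quant-type sub-dicts matter to this helper
config = {
    "float_to_observed_custom_module_class": {
        "static": {CustomA: CustomA},
        "dynamic": {CustomB: CustomB},
    }
}
print(get_custom_module_class_keys(config, "float_to_observed_custom_module_class"))
# expected: [CustomA, CustomB] in some order (a set is used to dedup)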
diff --git a/torch/quantization/quant_type.py b/torch/quantization/quant_type.py
index 212dec1fe28c..463d086b39b6 100644
--- a/torch/quantization/quant_type.py
+++ b/torch/quantization/quant_type.py
@@ -1,4 +1,3 @@
-
 import enum

 # Quantization type (dynamic quantization, static quantization).
@@ -7,3 +6,14 @@ class QuantType(enum.IntEnum):
     DYNAMIC = 0
     STATIC = 1
     QAT = 2
+    WEIGHT_ONLY = 3
+
+
+def quant_type_to_str(quant_type):
+    m = {
+        QuantType.STATIC: "static",
+        QuantType.DYNAMIC: "dynamic",
+        QuantType.QAT: "qat",
+        QuantType.WEIGHT_ONLY: "weight_only",
+    }
+    return m[quant_type]
diff --git a/torch/quantization/quantize_fx.py b/torch/quantization/quantize_fx.py
index 001c797ad6b1..e3a116be1785 100644
--- a/torch/quantization/quantize_fx.py
+++ b/torch/quantization/quantize_fx.py
@@ -4,6 +4,7 @@
 from .fx import Fuser  # noqa: F401
 from .fx import Quantizer  # noqa: F401
 from .fx.utils import graph_pretty_str  # noqa: F401
+from .fx.utils import get_custom_module_class_keys  # noqa: F401

 def _check_is_graph_module(model):
     if not isinstance(model, GraphModule):
@@ -74,9 +75,9 @@ def _prepare_fx(model, qconfig_dict, inplace, prepare_custom_config_dict=None, i
     # standalone module and custom module config are applied in top level module
     standalone_module_names = prepare_custom_config_dict.get('standalone_module_name', [])
     skipped_module_names += standalone_module_names
-    custom_module_config = prepare_custom_config_dict.get('float_to_observed_custom_module_class', {})
-    custom_module_classes = list(custom_module_config.keys())
-    skipped_module_classes += custom_module_classes
+    float_custom_module_classes = get_custom_module_class_keys(
+        prepare_custom_config_dict, "float_to_observed_custom_module_class")
+    skipped_module_classes += float_custom_module_classes
     tracer = CustomTracer(skipped_module_names, skipped_module_classes)
     graph_module = GraphModule(model, tracer.trace(model))
     graph_module = _fuse_fx(graph_module, inplace, prepare_custom_config_dict)
@@ -178,8 +179,11 @@ def prepare_fx(model, qconfig_dict, inplace=False, prepare_custom_config_dict=No
         # user will manually define the corresponding observed
         # module class which has a from_float class method that converts
         # float custom module to observed custom module
+        # (only needed for static quantization)
         "float_to_observed_custom_module_class": {
-            CustomModule: ObservedCustomModule
+            "static": {
+                CustomModule: ObservedCustomModule
+            }
         },

         # the qualified names for the submodule that are not symbolically traceable
@@ -188,6 +192,7 @@ def prepare_fx(model, qconfig_dict, inplace=False, prepare_custom_config_dict=No
         ],

         # the module classes that are not symbolically traceable
+        # we'll also put dynamic/weight_only custom module here
         "non_traceable_module_class": [
             NonTraceableModule
         ],
@@ -313,7 +318,15 @@ def convert_fx(graph_module, inplace=False, debug=False, convert_custom_config_d
         # module class which has a from_observed class method that converts
         # observed custom module to quantized custom module
         "observed_to_quantized_custom_module_class": {
-            ObservedCustomModule: QuantizedCustomModule
+            "static": {
+                ObservedCustomModule: QuantizedCustomModule
+            },
+            "dynamic": {
+                ObservedCustomModule: QuantizedCustomModule
+            },
+            "weight_only": {
+                ObservedCustomModule: QuantizedCustomModule
+            }
         }
     }
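
Editor's note: the strings returned by quant_type_to_str are exactly the keys used in the nested custom module config dicts documented above. A short check written for this edit:

from torch.quantization import QuantType, quant_type_to_str

for qt in (QuantType.DYNAMIC, QuantType.STATIC, QuantType.QAT, QuantType.WEIGHT_ONLY):
    print(quant_type_to_str(qt))  # dynamic, static, qat, weight_only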