From 89b4899ea5363fd69872c0cabf0dedea2dc533c8 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Wed, 23 Dec 2020 22:34:54 -0800 Subject: [PATCH] [quant][graphmode][fx] Standalone module support {input/output}_quantized_idxs (#49754) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/49754 This PR adds the support for {input/output}_quantized_idxs for standalone module. if input_quantized_idxs = [] and output_quantized_idxs = [], the standalone module will be expecting float input and produce float output, and will quantize the input and dequantize output internally if input_quantized_idxs = [0] and otuput_qiuantized_idxs = [0], the standalone module will be expecting quantized input and produce quantized output, the input will be quantized in the parent module, and output will be dequantized in the parent module as well, this is similar to current quantized modules like nn.quantized.Conv2d For more details, please see the test case Test Plan: python test/test_quantization.py TestQuantizeFx.test_standalone_module Imported from OSS Reviewed By: raghuramank100 Differential Revision: D25684692 fbshipit-source-id: 900360e01c0e35b26fe85f4a887dc1fd6f7bfb66 --- test/quantization/test_quantize_fx.py | 126 +++++++++++++---- torch/quantization/fx/observed_module.py | 10 +- .../quantization/fx/quantization_patterns.py | 4 +- torch/quantization/fx/quantize.py | 132 +++++++++++++----- torch/quantization/quantize_fx.py | 23 ++- 5 files changed, 221 insertions(+), 74 deletions(-) diff --git a/test/quantization/test_quantize_fx.py b/test/quantization/test_quantize_fx.py index 66324f928f04..0aba50779432 100644 --- a/test/quantization/test_quantize_fx.py +++ b/test/quantization/test_quantize_fx.py @@ -570,7 +570,16 @@ def forward(self, x): m = convert_fx(m) m(tensor_input) - def test_standalone_module(self): + def _test_standalone_module( + self, + interface_config, + prepare_count_check, + standalone_prepare_count_check, + convert_count_check, + standalone_convert_count_check): + """ Test standalone module with different quantized input/quantized output + configurations + """ class StandaloneModule(torch.nn.Module): def __init__(self): super().__init__() @@ -610,45 +619,32 @@ def forward(self, x): original_ref_m.conv2.weight = torch.nn.Parameter(original_m.standalone.conv.weight.detach()) original_ref_m.conv2.bias = torch.nn.Parameter(original_m.standalone.conv.bias.detach()) - qconfig_dict = {"": default_qconfig} - config_name = {"standalone_module_name": [("standalone", None, None)]} - config_class = {"standalone_module_class": [(StandaloneModule, None, None)]} - for prepare_config in [config_name, config_class]: + for is_name in [True, False]: + if is_name: + prepare_config = { + "standalone_module_name": [("standalone", None, interface_config)] + } + else: + prepare_config = { + "standalone_module_class": [(StandaloneModule, None, interface_config)] + } + original_m_copy = copy.deepcopy(original_m) original_ref_m_copy = copy.deepcopy(original_ref_m) + + qconfig_dict = {"": default_qconfig} # check prepared model m = prepare_fx( original_m_copy, qconfig_dict, prepare_custom_config_dict=prepare_config) # calibration m(data) - # input and output of first conv, observer for standalone module - # will be inserted in the standalone module itself - count_check = { - ns.call_module(torch.quantization.MinMaxObserver): 2 - } - self.checkGraphModuleNodes(m, expected_node_occurrence=count_check) - # for input and output of conv in the standalone module - count_check = { - ns.call_module(torch.quantization.MinMaxObserver): 2 - } - self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=count_check) + self.checkGraphModuleNodes(m, expected_node_occurrence=prepare_count_check) + self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=standalone_prepare_count_check) # check converted/quantized model m = convert_fx(m) - count_check = { - ns.call_function(torch.quantize_per_tensor) : 1, - ns.call_module(nnq.Conv2d) : 1, - ns.call_method('dequantize') : 1, - } - self.checkGraphModuleNodes(m, expected_node_occurrence=count_check) - count_check = { - # standalone module will take float as input and output - # so we'll see quantize and dequantize in the modoule - ns.call_function(torch.quantize_per_tensor) : 1, - ns.call_module(nnq.Conv2d): 1, - ns.call_method('dequantize') : 1, - } - self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=count_check) + self.checkGraphModuleNodes(m, expected_node_occurrence=convert_count_check) + self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=standalone_convert_count_check) res = m(data) # quantize the reference model @@ -658,6 +654,76 @@ def forward(self, x): ref_res = ref_m(data) self.assertEqual(res, ref_res) + def test_standalone_module_float_interface(self): + float_interface_config = { + "input_quantized_idxs": [], # float input + "output_quantized_idxs": [], # float output + } + interface_config = float_interface_config + # input and output of first conv, observer for standalone module + # will be inserted in the standalone module itself + prepare_count_check = { + ns.call_module(torch.quantization.MinMaxObserver): 2 + } + # for input and output of conv in the standalone module + standalone_prepare_count_check = { + ns.call_module(torch.quantization.MinMaxObserver): 2 + } + convert_count_check = { + ns.call_function(torch.quantize_per_tensor) : 1, + ns.call_module(nnq.Conv2d) : 1, + ns.call_method("dequantize") : 1, + } + standalone_convert_count_check = { + # standalone module will take float as input and output + # so we'll see quantize and dequantize in the modoule + ns.call_function(torch.quantize_per_tensor) : 1, + ns.call_module(nnq.Conv2d): 1, + ns.call_method("dequantize") : 1, + } + self._test_standalone_module( + interface_config, + prepare_count_check, + standalone_prepare_count_check, + convert_count_check, + standalone_convert_count_check) + + def test_standalone_module_quantized_interface(self): + quantized_interface_config = { + "input_quantized_idxs": [0], # quantized input + "output_quantized_idxs": [0], # quantized output + } + interface_config = quantized_interface_config + # observer for input and output of first conv + prepare_count_check = { + ns.call_module(torch.quantization.MinMaxObserver): 2 + } + # for output of conv in the standalone module + standalone_prepare_count_check = { + ns.call_module(torch.quantization.MinMaxObserver): 1 + } + convert_count_check = { + # quantizing input for conv + ns.call_function(torch.quantize_per_tensor) : 1, + ns.call_module(nnq.Conv2d) : 1, + # dequantizing output of standalone module + ns.call_method("dequantize") : 1, + } + standalone_convert_count_check = { + # quantization of input happens in parent module + # quantization of output happens in the quantized conv module + ns.call_function(torch.quantize_per_tensor) : 0, + ns.call_module(nnq.Conv2d): 1, + # dequantization for output happens in parent module + ns.call_method("dequantize") : 0, + } + self._test_standalone_module( + interface_config, + prepare_count_check, + standalone_prepare_count_check, + convert_count_check, + standalone_convert_count_check) + @skipIfNoFBGEMM def test_qconfig_none(self): class M(torch.nn.Module): diff --git a/torch/quantization/fx/observed_module.py b/torch/quantization/fx/observed_module.py index a95bc184fa10..808a3b36fb4a 100644 --- a/torch/quantization/fx/observed_module.py +++ b/torch/quantization/fx/observed_module.py @@ -2,11 +2,11 @@ import copy from torch.fx import GraphModule # type: ignore from torch.fx.graph import Graph -from typing import Union, Dict, Any +from typing import Union, Dict, Any, List class ObservedGraphModule(GraphModule): - def get_preserved_attr_names(self): + def get_preserved_attr_names(self) -> List[str]: return ['_activation_post_process_map', '_patterns', '_qconfig_map', @@ -35,6 +35,12 @@ def is_observed_module(module: Any) -> bool: return isinstance(module, ObservedGraphModule) class ObservedStandaloneGraphModule(ObservedGraphModule): + def get_preserved_attr_names(self) -> List[str] : + return super().get_preserved_attr_names() + [ + "_standalone_module_input_quantized_idxs", + "_standalone_module_output_quantized_idxs" + ] + def __deepcopy__(self, memo): fake_mod = torch.nn.Module() fake_mod.__dict__ = copy.deepcopy(self.__dict__) diff --git a/torch/quantization/fx/quantization_patterns.py b/torch/quantization/fx/quantization_patterns.py index a1e601332d4a..ed2f7e35659c 100644 --- a/torch/quantization/fx/quantization_patterns.py +++ b/torch/quantization/fx/quantization_patterns.py @@ -753,10 +753,10 @@ def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable, qconfig = quantizer.qconfig_map[node.name] convert = torch.quantization.quantize_fx._convert_standalone_module_fx # type: ignore observed_standalone_module = quantizer.modules[node.target] + input_quantized_idxs = observed_standalone_module._standalone_module_input_quantized_idxs quantized_standalone_module = convert(observed_standalone_module, debug=debug) parent_name, name = _parent_name(node.target) # update the modules dict setattr(quantizer.modules[parent_name], name, quantized_standalone_module) quantizer.modules[node.target] = quantized_standalone_module - # standalone module takes float input - return quantizer.quantized_graph.node_copy(node, load_arg(quantized=False)) + return quantizer.quantized_graph.node_copy(node, load_arg(quantized=input_quantized_idxs)) diff --git a/torch/quantization/fx/quantize.py b/torch/quantization/fx/quantize.py index af9496a66a63..d821f9610b7f 100644 --- a/torch/quantization/fx/quantize.py +++ b/torch/quantization/fx/quantize.py @@ -102,14 +102,15 @@ def insert_observer( 'call_module', observer_name, (load_arg(node),), {}) observed_node_names_set.add(node.name) -def insert_observer_for_special_module( +def maybe_insert_observer_for_special_module( quantize_handler: QuantizeHandler, modules: Dict[str, torch.nn.Module], - prepare_custom_config_dict: Any, qconfig: Any, node: Node): + prepare_custom_config_dict: Any, qconfig: Any, node: Node) -> Optional[List[int]]: """ Insert observer for custom module and standalone module Returns: standalone_module_input_idxs: the indexs for inputs that needs to be observed by parent module """ assert modules is not None + standalone_module_input_idxs = None if isinstance(quantize_handler, CustomModuleQuantizeHandler): custom_module = modules[node.target] # type: ignore custom_module_class_mapping = prepare_custom_config_dict.get( @@ -129,19 +130,22 @@ def insert_observer_for_special_module( class_config_map = {x[0]: (x[1], x[2]) for x in standalone_module_class_configs} name_config_map = {x[0]: (x[1], x[2]) for x in standalone_module_name_configs} config = class_config_map.get(type(standalone_module), (None, None)) - config = name_config_map.get(node.target, (None, None)) - standalone_module_qconfig_dict = {"": qconfig} if config[0] is None else config[0] - standalone_prepare_config_dict = {} if config[1] is None else config[1] + config = name_config_map.get(node.target, config) + sm_qconfig_dict = {"": qconfig} if config[0] is None else config[0] + sm_prepare_config_dict = {} if config[1] is None else config[1] prepare = \ torch.quantization.quantize_fx._prepare_standalone_module_fx # type: ignore observed_standalone_module = \ - prepare(standalone_module, standalone_module_qconfig_dict, standalone_prepare_config_dict) + prepare(standalone_module, sm_qconfig_dict, sm_prepare_config_dict) + standalone_module_input_idxs = observed_standalone_module.\ + _standalone_module_input_quantized_idxs observed_standalone_module = mark_observed_standalone_module( observed_standalone_module) parent_name, name = _parent_name(node.target) setattr(modules[parent_name], name, observed_standalone_module) modules[node.target] = observed_standalone_module # type: ignore + return standalone_module_input_idxs def insert_observer_for_output_of_the_node( node: Node, @@ -155,7 +159,8 @@ def insert_observer_for_output_of_the_node( observed_graph: Graph, load_arg: Callable, observed_node_names_set: Set[str], - matched_nodes: Optional[List[Node]]): + matched_nodes: Optional[List[Node]], + standalone_module_input_idxs: Optional[List[int]]): """ Insert observer/fake_quantize module for output of the observed module if needed """ @@ -215,8 +220,11 @@ def input_is_observed(arg): observed_node_names_set.add(node.name) elif isinstance(quantize_handler, StandaloneModuleQuantizeHandler): - # output is observed in the standalone module - return + assert node.op == "call_module" + output_is_quantized = 0 in \ + modules[node.target]._standalone_module_output_quantized_idxs # type: ignore + if output_is_quantized: + observed_node_names_set.add(node.name) elif (quantize_handler.all_node_args and input_output_observed(quantize_handler)): # observer for outputs @@ -226,6 +234,16 @@ def input_is_observed(arg): activation_post_process_map, env, observed_graph, load_arg, observed_node_names_set) + # insert observer for input of standalone module + if standalone_module_input_idxs is not None: + for idx in standalone_module_input_idxs: + if node.args[idx].name not in observed_node_names_set: # type: ignore + new_observer = qconfig.activation() + insert_observer( + node, new_observer, model, + activation_post_process_map, env, observed_graph, + load_arg, observed_node_names_set) + def insert_observer_for_input_arg_of_observed_node( node: Node, observed_node_names_set: Set[str], quants: Dict[str, Tuple[DefaultQuantizeHandler, Callable]], @@ -373,10 +391,19 @@ def _prepare(self, model: GraphModule, qconfig_dict: Any, """ standalone_module means it a submodule that is not inlined in parent module, and will be quantized separately as one unit. - When we are preparing a standalone module: - both input and output are observed in prepared standalone module + How the standalone module is observed is specified by `input_quantized_idxs` and + `output_quantized_idxs` in the prepare_custom_config for the standalone module Returns: model(GraphModule): prepared standalone module + attributes: + _standalone_module_input_quantized_idxs(List[Int]): a list of + indexes for the graph input that is expected to be quantized, + same as input_quantized_idxs configuration provided + for the standalone module + _standalone_module_output_quantized_idxs(List[Int]): a list of + indexs for the graph output that is quantized + same as input_quantized_idxs configuration provided + for the standalone module """ if prepare_custom_config_dict is None: prepare_custom_config_dict = {} @@ -430,8 +457,6 @@ def _prepare(self, model: GraphModule, qconfig_dict: Any, def load_arg(a): return map_arg(a, lambda node: env[node.name]) - # indexes for the inputs that needs to be observed - standalone_module_observed_input_idxs: List[int] = [] graph_inputs = [] for node in model.graph.nodes: if node.op == 'placeholder': @@ -487,14 +512,15 @@ def load_arg(a): # parent if qconfig is not None: assert obj is not None - insert_observer_for_special_module( - obj, self.modules, prepare_custom_config_dict, qconfig, - node) + standalone_module_input_idxs = \ + maybe_insert_observer_for_special_module( + obj, self.modules, prepare_custom_config_dict, qconfig, + node) insert_observer_for_output_of_the_node( node, obj, qconfig, self.modules, model, pattern, self.activation_post_process_map, env, observed_graph, load_arg, observed_node_names_set, - matched_nodes) + matched_nodes, standalone_module_input_idxs) else: env[node.name] = observed_graph.node_copy(node, load_arg) @@ -516,6 +542,19 @@ def load_arg(a): model = GraphModule(model, observed_graph) self.save_state(model) model = mark_observed_module(model) + if is_standalone_module: + assert result_node is not None + assert isinstance(result_node.args[0], Node), \ + "standalone module only supports returning simple value currently"\ + "(not tuple, dict etc.)" + # indicator for whether output is observed or not. + # This used for correctly quantize standalone modules + output_is_observed = \ + result_node.args[0].name in observed_node_names_set + # these inputs are observed in parent + model._standalone_module_input_quantized_idxs = \ + input_quantized_idxs + model._standalone_module_output_quantized_idxs = output_quantized_idxs return model def save_state(self, observed: GraphModule) -> None: @@ -569,8 +608,10 @@ def _convert(self, model: GraphModule, debug: bool = False, """ standalone_module means it a submodule that is not inlined in parent module, and will be quantized separately as one unit. - Returns a quantized standalone module which accepts float input - and produces float output. + Returns a quantized standalone module, whether input/output is quantized is + specified by prepare_custom_config_dict, with + input_quantized_idxs, output_quantized_idxs, please + see docs for prepare_fx for details """ if convert_custom_config_dict is None: convert_custom_config_dict = {} @@ -627,36 +668,50 @@ def load_x(n: Node) -> Node: else: return env[n.name] - def load_arg(quantized: Optional[Union[List[Any], bool, Tuple[Any, ...]]] + def load_arg(quantized: Optional[Union[List[int], bool, Tuple[int, ...]]] ) -> Callable[[Node], Argument]: """ Input: quantized, which can be None, list, boolean or tuple - - if quantized is a list or tuple, then arg should be a list and - the args with corresponding indexes will be quantized - - if quantized is a boolean, then all args will be - quantized/not quantized - if quantized is None, then we'll load the node as long as it exists + - if quantized is a boolean, then all args will be + quantized/not quantized + - if quantized is an empty list or tuple, then it is the same as load_arg(quantized=False) + - if quantized is a list or tuple, then arg should be a list and + the args with corresponding indexes will be quantized Output: fn which takes arg_or_args, and loads them from the corresponding environment depending on the value of quantized. """ assert quantized is None or \ isinstance(quantized, (tuple, list, bool)), type(quantized) + if isinstance(quantized, (tuple, list)) and len(quantized) == 0: + # empty tuple or list means nothing is quantized + quantized = False def load_arg_impl(arg_or_args): - if quantized is None: + # we'll update the format of `quantized` + # to better match arg_or_args + updated_quantized: Optional[Union[List[int], bool, Tuple[int, ...]]] = quantized + + if isinstance(quantized, (tuple, list)) and \ + len(quantized) == 1 and isinstance(arg_or_args, Node): + # when argument is one Node instead of tuple, we just need to check + # 0 is in the quantized list + updated_quantized = 0 in quantized + + if updated_quantized is None: return map_arg(arg_or_args, load_x) - if isinstance(quantized, bool): + if isinstance(updated_quantized, bool): return map_arg( arg_or_args, - load_quantized if quantized else load_non_quantized) - elif isinstance(quantized, (tuple, list)): + load_quantized if updated_quantized else load_non_quantized) + elif isinstance(updated_quantized, (tuple, list)): assert isinstance(arg_or_args, (tuple, list)), arg_or_args loaded_args = [] # for now, we only support quantizing positional arguments for i, a in enumerate(arg_or_args): - if i in quantized: + if i in updated_quantized: loaded_args.append(map_arg(a, load_quantized)) else: loaded_args.append(map_arg(a, load_non_quantized)) @@ -690,10 +745,10 @@ def node_arg_is_quantized(node_arg: Any) -> bool: def is_output_quantized(node: Node, obj: QuantizeHandler) -> bool: """ Check if output node is quantized or not """ assert self.modules is not None - # by default the output is expected to be quantized + # by default the output for a quantizable node is expected to be quantized quantized = True - # Need to get correct quantized/non-quantized state for the output + # Need to get correct quantized/non-quantized state forn the output # of CopyNode if type(obj) in [ CopyNode, @@ -750,7 +805,7 @@ def insert_quantize_node(node: Node) -> None: "output_quantized_idxs", []) for node in model.graph.nodes: - if node.op == 'output': + if node.op == "output": cur_output_node_idx = output_node_seen_cnt output_node_seen_cnt += 1 if cur_output_node_idx in output_quantized_idxs: @@ -775,12 +830,19 @@ def insert_quantize_node(node: Node) -> None: quantized = False else: assert obj is not None + # We will get whether the output is quantized or not before + # convert for standalone module and after convert + # for non-standalone module, since _standalone_module_output_quantized_idxs + # is only available in observed standalone module + if is_observed_standalone_module_node: + out_quant_idxs = self.modules[node.target]._standalone_module_output_quantized_idxs + assert len(out_quant_idxs) <= 1, "Currently standalone only support one output" + quantized = 0 in out_quant_idxs + result = obj.convert( self, node, load_arg, debug=debug, convert_custom_config_dict=convert_custom_config_dict) - if is_observed_standalone_module_node: - quantized = False - else: + if not is_observed_standalone_module_node: quantized = is_output_quantized(node, obj) if quantized: diff --git a/torch/quantization/quantize_fx.py b/torch/quantization/quantize_fx.py index cba104b8f783..89ba877ffe78 100644 --- a/torch/quantization/quantize_fx.py +++ b/torch/quantization/quantize_fx.py @@ -107,8 +107,20 @@ def _prepare_standalone_module_fx( standalone_module means it a submodule that is not inlined in parent module, and will be quantized separately as one unit. - Both input and output of the module are observed in the - standalone module. + How the standalone module is observed is specified by `input_quantized_idxs` and + `output_quantized_idxs` in the prepare_custom_config for the standalone module + + Returns: + model(GraphModule): prepared standalone module + attributes: + _standalone_module_input_quantized_idxs(List[Int]): a list of + indexes for the graph input that is expected to be quantized, + same as input_quantized_idxs configuration provided + for the standalone module + _standalone_module_output_quantized_idxs(List[Int]): a list of + indexs for the graph output that is quantized + same as input_quantized_idxs configuration provided + for the standalone module """ return _prepare_fx(model, qconfig_dict, prepare_custom_config_dict, is_standalone_module=True) @@ -378,8 +390,9 @@ def _convert_standalone_module_fx( r""" [Internal use only] Convert a model produced by :func:`~torch.quantization.prepare_standalone_module_fx` and convert it to a quantized model - Return: - A quantized standalone module which accepts float input - and produces float output. + Returns a quantized standalone module, whether input/output is quantized is + specified by prepare_custom_config_dict, with + input_quantized_idxs, output_quantized_idxs, please + see docs for prepare_fx for details """ return _convert_fx(graph_module, debug, convert_custom_config_dict, is_standalone_module=True)