From 46cf6d332f075ed90d3baf21c32de51e4f304549 Mon Sep 17 00:00:00 2001
From: Mike Ruberry
Date: Thu, 24 Dec 2020 15:49:01 -0800
Subject: [PATCH] Revert D25684692: [quant][graphmode][fx] Standalone module
 support {input/output}_quantized_idxs

Test Plan: revert-hammer

Differential Revision:
D25684692 (https://github.com/pytorch/pytorch/commit/89b4899ea5363fd69872c0cabf0dedea2dc533c8)

Original commit changeset: 900360e01c0e

fbshipit-source-id: 8b65fa8fbc7b364fbddb5f23cc696cd9b7db98cd
---
 test/quantization/test_quantize_fx.py        | 126 ++++-------------
 torch/quantization/fx/observed_module.py     |  10 +-
 .../quantization/fx/quantization_patterns.py |   4 +-
 torch/quantization/fx/quantize.py            | 132 +++++-------------
 torch/quantization/quantize_fx.py            |  23 +--
 5 files changed, 74 insertions(+), 221 deletions(-)

diff --git a/test/quantization/test_quantize_fx.py b/test/quantization/test_quantize_fx.py
index 0aba50779432..66324f928f04 100644
--- a/test/quantization/test_quantize_fx.py
+++ b/test/quantization/test_quantize_fx.py
@@ -570,16 +570,7 @@ def forward(self, x):
         m = convert_fx(m)
         m(tensor_input)
 
-    def _test_standalone_module(
-            self,
-            interface_config,
-            prepare_count_check,
-            standalone_prepare_count_check,
-            convert_count_check,
-            standalone_convert_count_check):
-        """ Test standalone module with different quantized input/quantized output
-        configurations
-        """
+    def test_standalone_module(self):
         class StandaloneModule(torch.nn.Module):
             def __init__(self):
                 super().__init__()
@@ -619,32 +610,45 @@ def forward(self, x):
         original_ref_m.conv2.weight = torch.nn.Parameter(original_m.standalone.conv.weight.detach())
         original_ref_m.conv2.bias = torch.nn.Parameter(original_m.standalone.conv.bias.detach())
 
-        for is_name in [True, False]:
-            if is_name:
-                prepare_config = {
-                    "standalone_module_name": [("standalone", None, interface_config)]
-                }
-            else:
-                prepare_config = {
-                    "standalone_module_class": [(StandaloneModule, None, interface_config)]
-                }
-
+        qconfig_dict = {"": default_qconfig}
+        config_name = {"standalone_module_name": [("standalone", None, None)]}
+        config_class = {"standalone_module_class": [(StandaloneModule, None, None)]}
+        for prepare_config in [config_name, config_class]:
             original_m_copy = copy.deepcopy(original_m)
             original_ref_m_copy = copy.deepcopy(original_ref_m)
-
-            qconfig_dict = {"": default_qconfig}
             # check prepared model
             m = prepare_fx(
                 original_m_copy, qconfig_dict, prepare_custom_config_dict=prepare_config)
             # calibration
             m(data)
-            self.checkGraphModuleNodes(m, expected_node_occurrence=prepare_count_check)
-            self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=standalone_prepare_count_check)
+            # input and output of first conv, observer for standalone module
+            # will be inserted in the standalone module itself
+            count_check = {
+                ns.call_module(torch.quantization.MinMaxObserver): 2
+            }
+            self.checkGraphModuleNodes(m, expected_node_occurrence=count_check)
+            # for input and output of conv in the standalone module
+            count_check = {
+                ns.call_module(torch.quantization.MinMaxObserver): 2
+            }
+            self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=count_check)
             # check converted/quantized model
             m = convert_fx(m)
-            self.checkGraphModuleNodes(m, expected_node_occurrence=convert_count_check)
-            self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=standalone_convert_count_check)
+            count_check = {
+                ns.call_function(torch.quantize_per_tensor) : 1,
+                ns.call_module(nnq.Conv2d) : 1,
+                ns.call_method('dequantize') : 1,
+            }
+            self.checkGraphModuleNodes(m, expected_node_occurrence=count_check)
+            count_check = {
+                # standalone module will take float as input and output
+                # so we'll see quantize and dequantize in the module
+                ns.call_function(torch.quantize_per_tensor) : 1,
+                ns.call_module(nnq.Conv2d): 1,
+                ns.call_method('dequantize') : 1,
+            }
+            self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=count_check)
             res = m(data)
 
             # quantize the reference model
@@ -654,76 +658,6 @@ def forward(self, x):
             ref_res = ref_m(data)
             self.assertEqual(res, ref_res)
 
-    def test_standalone_module_float_interface(self):
-        float_interface_config = {
-            "input_quantized_idxs": [],  # float input
-            "output_quantized_idxs": [],  # float output
-        }
-        interface_config = float_interface_config
-        # input and output of first conv, observer for standalone module
-        # will be inserted in the standalone module itself
-        prepare_count_check = {
-            ns.call_module(torch.quantization.MinMaxObserver): 2
-        }
-        # for input and output of conv in the standalone module
-        standalone_prepare_count_check = {
-            ns.call_module(torch.quantization.MinMaxObserver): 2
-        }
-        convert_count_check = {
-            ns.call_function(torch.quantize_per_tensor) : 1,
-            ns.call_module(nnq.Conv2d) : 1,
-            ns.call_method("dequantize") : 1,
-        }
-        standalone_convert_count_check = {
-            # standalone module will take float as input and output
-            # so we'll see quantize and dequantize in the modoule
-            ns.call_function(torch.quantize_per_tensor) : 1,
-            ns.call_module(nnq.Conv2d): 1,
-            ns.call_method("dequantize") : 1,
-        }
-        self._test_standalone_module(
-            interface_config,
-            prepare_count_check,
-            standalone_prepare_count_check,
-            convert_count_check,
-            standalone_convert_count_check)
-
-    def test_standalone_module_quantized_interface(self):
-        quantized_interface_config = {
-            "input_quantized_idxs": [0],  # quantized input
-            "output_quantized_idxs": [0],  # quantized output
-        }
-        interface_config = quantized_interface_config
-        # observer for input and output of first conv
-        prepare_count_check = {
-            ns.call_module(torch.quantization.MinMaxObserver): 2
-        }
-        # for output of conv in the standalone module
-        standalone_prepare_count_check = {
-            ns.call_module(torch.quantization.MinMaxObserver): 1
-        }
-        convert_count_check = {
-            # quantizing input for conv
-            ns.call_function(torch.quantize_per_tensor) : 1,
-            ns.call_module(nnq.Conv2d) : 1,
-            # dequantizing output of standalone module
-            ns.call_method("dequantize") : 1,
-        }
-        standalone_convert_count_check = {
-            # quantization of input happens in parent module
-            # quantization of output happens in the quantized conv module
-            ns.call_function(torch.quantize_per_tensor) : 0,
-            ns.call_module(nnq.Conv2d): 1,
-            # dequantization for output happens in parent module
-            ns.call_method("dequantize") : 0,
-        }
-        self._test_standalone_module(
-            interface_config,
-            prepare_count_check,
-            standalone_prepare_count_check,
-            convert_count_check,
-            standalone_convert_count_check)
-
     @skipIfNoFBGEMM
     def test_qconfig_none(self):
         class M(torch.nn.Module):
diff --git a/torch/quantization/fx/observed_module.py b/torch/quantization/fx/observed_module.py
index 808a3b36fb4a..a95bc184fa10 100644
--- a/torch/quantization/fx/observed_module.py
+++ b/torch/quantization/fx/observed_module.py
@@ -2,11 +2,11 @@
 import copy
 from torch.fx import GraphModule  # type: ignore
 from torch.fx.graph import Graph
-from typing import Union, Dict, Any, List
+from typing import Union, Dict, Any
 
 class ObservedGraphModule(GraphModule):
 
-    def get_preserved_attr_names(self) -> List[str]:
+    def get_preserved_attr_names(self):
         return ['_activation_post_process_map',
                 '_patterns',
                 '_qconfig_map',
@@ -35,12 +35,6 @@ def is_observed_module(module: Any) -> bool:
     return isinstance(module, ObservedGraphModule)
 
 class ObservedStandaloneGraphModule(ObservedGraphModule):
-    def get_preserved_attr_names(self) -> List[str] :
-        return super().get_preserved_attr_names() + [
-            "_standalone_module_input_quantized_idxs",
-            "_standalone_module_output_quantized_idxs"
-        ]
-
     def __deepcopy__(self, memo):
         fake_mod = torch.nn.Module()
         fake_mod.__dict__ = copy.deepcopy(self.__dict__)
diff --git a/torch/quantization/fx/quantization_patterns.py b/torch/quantization/fx/quantization_patterns.py
index ed2f7e35659c..a1e601332d4a 100644
--- a/torch/quantization/fx/quantization_patterns.py
+++ b/torch/quantization/fx/quantization_patterns.py
@@ -753,10 +753,10 @@ def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
         qconfig = quantizer.qconfig_map[node.name]
         convert = torch.quantization.quantize_fx._convert_standalone_module_fx  # type: ignore
         observed_standalone_module = quantizer.modules[node.target]
-        input_quantized_idxs = observed_standalone_module._standalone_module_input_quantized_idxs
         quantized_standalone_module = convert(observed_standalone_module, debug=debug)
         parent_name, name = _parent_name(node.target)
         # update the modules dict
         setattr(quantizer.modules[parent_name], name, quantized_standalone_module)
         quantizer.modules[node.target] = quantized_standalone_module
-        return quantizer.quantized_graph.node_copy(node, load_arg(quantized=input_quantized_idxs))
+        # standalone module takes float input
+        return quantizer.quantized_graph.node_copy(node, load_arg(quantized=False))
diff --git a/torch/quantization/fx/quantize.py b/torch/quantization/fx/quantize.py
index d821f9610b7f..af9496a66a63 100644
--- a/torch/quantization/fx/quantize.py
+++ b/torch/quantization/fx/quantize.py
@@ -102,15 +102,14 @@ def insert_observer(
         'call_module', observer_name, (load_arg(node),), {})
     observed_node_names_set.add(node.name)
 
-def maybe_insert_observer_for_special_module(
+def insert_observer_for_special_module(
         quantize_handler: QuantizeHandler, modules: Dict[str, torch.nn.Module],
-        prepare_custom_config_dict: Any, qconfig: Any, node: Node) -> Optional[List[int]]:
+        prepare_custom_config_dict: Any, qconfig: Any, node: Node):
     """ Insert observer for custom module and standalone module
-    Returns: standalone_module_input_idxs: the indexs for inputs that needs
-      to be observed by parent module
     """
     assert modules is not None
-    standalone_module_input_idxs = None
     if isinstance(quantize_handler, CustomModuleQuantizeHandler):
         custom_module = modules[node.target]  # type: ignore
         custom_module_class_mapping = prepare_custom_config_dict.get(
@@ -130,22 +129,19 @@
         class_config_map = {x[0]: (x[1], x[2]) for x in standalone_module_class_configs}
         name_config_map = {x[0]: (x[1], x[2]) for x in standalone_module_name_configs}
         config = class_config_map.get(type(standalone_module), (None, None))
-        config = name_config_map.get(node.target, config)
-        sm_qconfig_dict = {"": qconfig} if config[0] is None else config[0]
-        sm_prepare_config_dict = {} if config[1] is None else config[1]
+        config = name_config_map.get(node.target, (None, None))
+        standalone_module_qconfig_dict = {"": qconfig} if config[0] is None else config[0]
+        standalone_prepare_config_dict = {} if config[1] is None else config[1]
         prepare = \
             torch.quantization.quantize_fx._prepare_standalone_module_fx  # type: ignore
         observed_standalone_module = \
-            prepare(standalone_module, sm_qconfig_dict, sm_prepare_config_dict)
-        standalone_module_input_idxs = observed_standalone_module.\
-            _standalone_module_input_quantized_idxs
+            prepare(standalone_module, standalone_module_qconfig_dict, standalone_prepare_config_dict)
         observed_standalone_module = mark_observed_standalone_module(
             observed_standalone_module)
         parent_name, name = _parent_name(node.target)
         setattr(modules[parent_name], name, observed_standalone_module)
         modules[node.target] = observed_standalone_module  # type: ignore
-    return standalone_module_input_idxs
 
 def insert_observer_for_output_of_the_node(
         node: Node,
@@ -159,8 +155,7 @@ def insert_observer_for_output_of_the_node(
         observed_graph: Graph,
         load_arg: Callable,
         observed_node_names_set: Set[str],
-        matched_nodes: Optional[List[Node]],
-        standalone_module_input_idxs: Optional[List[int]]):
+        matched_nodes: Optional[List[Node]]):
     """ Insert observer/fake_quantize module for output of the observed
     module if needed
     """
@@ -220,11 +215,8 @@ def input_is_observed(arg):
             observed_node_names_set.add(node.name)
     elif isinstance(quantize_handler, StandaloneModuleQuantizeHandler):
-        assert node.op == "call_module"
-        output_is_quantized = 0 in \
-            modules[node.target]._standalone_module_output_quantized_idxs  # type: ignore
-        if output_is_quantized:
-            observed_node_names_set.add(node.name)
+        # output is observed in the standalone module
+        return
     elif (quantize_handler.all_node_args and
             input_output_observed(quantize_handler)):
         # observer for outputs
@@ -234,16 +226,6 @@ def input_is_observed(arg):
             activation_post_process_map, env, observed_graph,
             load_arg, observed_node_names_set)
 
-    # insert observer for input of standalone module
-    if standalone_module_input_idxs is not None:
-        for idx in standalone_module_input_idxs:
-            if node.args[idx].name not in observed_node_names_set:  # type: ignore
-                new_observer = qconfig.activation()
-                insert_observer(
-                    node, new_observer, model,
-                    activation_post_process_map, env, observed_graph,
-                    load_arg, observed_node_names_set)
-
 def insert_observer_for_input_arg_of_observed_node(
         node: Node, observed_node_names_set: Set[str],
         quants: Dict[str, Tuple[DefaultQuantizeHandler, Callable]],
@@ -391,19 +373,10 @@ def _prepare(self, model: GraphModule, qconfig_dict: Any,
         """ standalone_module means it a submodule that is not inlined in
         parent module, and will be quantized separately as one unit.
-        How the standalone module is observed is specified by `input_quantized_idxs` and
-        `output_quantized_idxs` in the prepare_custom_config for the standalone module
+        When we are preparing a standalone module, both its input and
+        output are observed in the prepared standalone module.
         Returns:
             model(GraphModule): prepared standalone module
-            attributes:
-                _standalone_module_input_quantized_idxs(List[Int]): a list of
-                    indexes for the graph input that is expected to be quantized,
-                    same as input_quantized_idxs configuration provided
-                    for the standalone module
-                _standalone_module_output_quantized_idxs(List[Int]): a list of
-                    indexs for the graph output that is quantized
-                    same as input_quantized_idxs configuration provided
-                    for the standalone module
         """
         if prepare_custom_config_dict is None:
             prepare_custom_config_dict = {}
@@ -457,6 +430,8 @@ def load_arg(a):
             return map_arg(a, lambda node: env[node.name])
 
+        # indexes for the inputs that need to be observed
+        standalone_module_observed_input_idxs: List[int] = []
         graph_inputs = []
         for node in model.graph.nodes:
             if node.op == 'placeholder':
@@ -512,15 +487,14 @@ def load_arg(a):
                 # parent
                 if qconfig is not None:
                     assert obj is not None
-                    standalone_module_input_idxs = \
-                        maybe_insert_observer_for_special_module(
-                            obj, self.modules, prepare_custom_config_dict, qconfig,
-                            node)
+                    insert_observer_for_special_module(
+                        obj, self.modules, prepare_custom_config_dict, qconfig,
+                        node)
                     insert_observer_for_output_of_the_node(
                         node, obj, qconfig, self.modules, model, pattern,
                         self.activation_post_process_map, env, observed_graph,
                         load_arg, observed_node_names_set,
-                        matched_nodes, standalone_module_input_idxs)
+                        matched_nodes)
             else:
                 env[node.name] = observed_graph.node_copy(node, load_arg)
@@ -542,19 +516,6 @@ def load_arg(a):
         model = GraphModule(model, observed_graph)
         self.save_state(model)
         model = mark_observed_module(model)
-        if is_standalone_module:
-            assert result_node is not None
-            assert isinstance(result_node.args[0], Node), \
-                "standalone module only supports returning simple value currently"\
-                "(not tuple, dict etc.)"
-            # indicator for whether output is observed or not.
-            # This used for correctly quantize standalone modules
-            output_is_observed = \
-                result_node.args[0].name in observed_node_names_set
-            # these inputs are observed in parent
-            model._standalone_module_input_quantized_idxs = \
-                input_quantized_idxs
-            model._standalone_module_output_quantized_idxs = output_quantized_idxs
         return model
 
     def save_state(self, observed: GraphModule) -> None:
@@ -608,10 +569,8 @@ def _convert(self, model: GraphModule, debug: bool = False,
         """ standalone_module means it a submodule that is not inlined in
         parent module, and will be quantized separately as one unit.
-        Returns a quantized standalone module, whether input/output is quantized is
-        specified by prepare_custom_config_dict, with
-        input_quantized_idxs, output_quantized_idxs, please
-        see docs for prepare_fx for details
+        Returns a quantized standalone module which accepts float input
+        and produces float output.
""" if convert_custom_config_dict is None: convert_custom_config_dict = {} @@ -668,50 +627,36 @@ def load_x(n: Node) -> Node: else: return env[n.name] - def load_arg(quantized: Optional[Union[List[int], bool, Tuple[int, ...]]] + def load_arg(quantized: Optional[Union[List[Any], bool, Tuple[Any, ...]]] ) -> Callable[[Node], Argument]: """ Input: quantized, which can be None, list, boolean or tuple - - if quantized is None, then we'll load the node as long as it - exists - - if quantized is a boolean, then all args will be - quantized/not quantized - - if quantized is an empty list or tuple, then it is the same as load_arg(quantized=False) - if quantized is a list or tuple, then arg should be a list and the args with corresponding indexes will be quantized + - if quantized is a boolean, then all args will be + quantized/not quantized + - if quantized is None, then we'll load the node as long as it + exists Output: fn which takes arg_or_args, and loads them from the corresponding environment depending on the value of quantized. """ assert quantized is None or \ isinstance(quantized, (tuple, list, bool)), type(quantized) - if isinstance(quantized, (tuple, list)) and len(quantized) == 0: - # empty tuple or list means nothing is quantized - quantized = False def load_arg_impl(arg_or_args): - # we'll update the format of `quantized` - # to better match arg_or_args - updated_quantized: Optional[Union[List[int], bool, Tuple[int, ...]]] = quantized - - if isinstance(quantized, (tuple, list)) and \ - len(quantized) == 1 and isinstance(arg_or_args, Node): - # when argument is one Node instead of tuple, we just need to check - # 0 is in the quantized list - updated_quantized = 0 in quantized - - if updated_quantized is None: + if quantized is None: return map_arg(arg_or_args, load_x) - if isinstance(updated_quantized, bool): + if isinstance(quantized, bool): return map_arg( arg_or_args, - load_quantized if updated_quantized else load_non_quantized) - elif isinstance(updated_quantized, (tuple, list)): + load_quantized if quantized else load_non_quantized) + elif isinstance(quantized, (tuple, list)): assert isinstance(arg_or_args, (tuple, list)), arg_or_args loaded_args = [] # for now, we only support quantizing positional arguments for i, a in enumerate(arg_or_args): - if i in updated_quantized: + if i in quantized: loaded_args.append(map_arg(a, load_quantized)) else: loaded_args.append(map_arg(a, load_non_quantized)) @@ -745,10 +690,10 @@ def node_arg_is_quantized(node_arg: Any) -> bool: def is_output_quantized(node: Node, obj: QuantizeHandler) -> bool: """ Check if output node is quantized or not """ assert self.modules is not None - # by default the output for a quantizable node is expected to be quantized + # by default the output is expected to be quantized quantized = True - # Need to get correct quantized/non-quantized state forn the output + # Need to get correct quantized/non-quantized state for the output # of CopyNode if type(obj) in [ CopyNode, @@ -805,7 +750,7 @@ def insert_quantize_node(node: Node) -> None: "output_quantized_idxs", []) for node in model.graph.nodes: - if node.op == "output": + if node.op == 'output': cur_output_node_idx = output_node_seen_cnt output_node_seen_cnt += 1 if cur_output_node_idx in output_quantized_idxs: @@ -830,19 +775,12 @@ def insert_quantize_node(node: Node) -> None: quantized = False else: assert obj is not None - # We will get whether the output is quantized or not before - # convert for standalone module and after convert - # for non-standalone module, 
-                # is only available in observed standalone module
-                if is_observed_standalone_module_node:
-                    out_quant_idxs = self.modules[node.target]._standalone_module_output_quantized_idxs
-                    assert len(out_quant_idxs) <= 1, "Currently standalone only support one output"
-                    quantized = 0 in out_quant_idxs
-
                 result = obj.convert(
                     self, node, load_arg, debug=debug,
                     convert_custom_config_dict=convert_custom_config_dict)
-                if not is_observed_standalone_module_node:
+                if is_observed_standalone_module_node:
+                    quantized = False
+                else:
                     quantized = is_output_quantized(node, obj)
 
             if quantized:
diff --git a/torch/quantization/quantize_fx.py b/torch/quantization/quantize_fx.py
index 89ba877ffe78..cba104b8f783 100644
--- a/torch/quantization/quantize_fx.py
+++ b/torch/quantization/quantize_fx.py
@@ -107,20 +107,8 @@ def _prepare_standalone_module_fx(
     standalone_module means it a submodule that is not inlined in parent module,
     and will be quantized separately as one unit.
 
-    How the standalone module is observed is specified by `input_quantized_idxs` and
-    `output_quantized_idxs` in the prepare_custom_config for the standalone module
-
-    Returns:
-        model(GraphModule): prepared standalone module
-        attributes:
-            _standalone_module_input_quantized_idxs(List[Int]): a list of
-                indexes for the graph input that is expected to be quantized,
-                same as input_quantized_idxs configuration provided
-                for the standalone module
-            _standalone_module_output_quantized_idxs(List[Int]): a list of
-                indexs for the graph output that is quantized
-                same as input_quantized_idxs configuration provided
-                for the standalone module
+    Both input and output of the module are observed in the
+    standalone module.
     """
     return _prepare_fx(model, qconfig_dict, prepare_custom_config_dict, is_standalone_module=True)
 
@@ -390,9 +378,8 @@ def _convert_standalone_module_fx(
     r""" [Internal use only] Convert a model produced by :func:`~torch.quantization.prepare_standalone_module_fx`
     and convert it to a quantized model
 
-    Returns a quantized standalone module, whether input/output is quantized is
-    specified by prepare_custom_config_dict, with
-    input_quantized_idxs, output_quantized_idxs, please
-    see docs for prepare_fx for details
+    Return:
+        A quantized standalone module which accepts float input
+        and produces float output.
     """
     return _convert_fx(graph_module, debug, convert_custom_config_dict, is_standalone_module=True)
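
Note for reviewers: after this revert, a standalone module always observes its own
input and output and exchanges float tensors with its parent. A minimal end-to-end
sketch of the post-revert workflow follows; the module names (SubModule, MainModule)
are illustrative stand-ins, not taken from the test suite, and the config tuple is
(name-or-class, qconfig_dict, prepare_custom_config_dict) with no third interface
config anymore:

    import torch
    from torch.quantization import default_qconfig
    from torch.quantization.quantize_fx import prepare_fx, convert_fx

    class SubModule(torch.nn.Module):  # hypothetical standalone module
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(1, 1, 1)

        def forward(self, x):
            return self.conv(x)

    class MainModule(torch.nn.Module):  # hypothetical parent module
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(1, 1, 1)
            self.standalone = SubModule()

        def forward(self, x):
            return self.standalone(self.conv(x))

    float_model = MainModule().eval()
    qconfig_dict = {"": default_qconfig}
    # the standalone module is named by its attribute path in the parent
    prepare_custom_config_dict = {
        "standalone_module_name": [("standalone", None, None)]
    }
    prepared = prepare_fx(float_model, qconfig_dict,
                          prepare_custom_config_dict=prepare_custom_config_dict)
    prepared(torch.randn(1, 1, 4, 4))  # calibration pass
    quantized = convert_fx(prepared)   # standalone takes/returns float tensors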
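The load_arg(quantized=...) contract this revert restores in _convert can also be
summarized with a toy re-implementation over plain dicts; this is only a sketch of
the dispatch rules from the docstring above (the real code maps fx Nodes through
quantized/non-quantized environments), with load_args_demo a hypothetical helper:

    from typing import List, Optional, Tuple, Union

    def load_args_demo(quantized: Optional[Union[bool, List[int], Tuple[int, ...]]],
                       args: tuple, quant_env: dict, float_env: dict) -> tuple:
        # quantized is None: take each arg from whichever environment defines it
        if quantized is None:
            return tuple(quant_env.get(a, float_env.get(a)) for a in args)
        # quantized is a bool: every arg comes from a single environment
        if isinstance(quantized, bool):
            env = quant_env if quantized else float_env
            return tuple(env[a] for a in args)
        # quantized is a list/tuple of positional indexes: only those indexes
        # are loaded from the quantized environment
        return tuple(quant_env[a] if i in quantized else float_env[a]
                     for i, a in enumerate(args))

    # e.g. load_args_demo([0], ("x", "y"), {"x": "xq"}, {"x": "xf", "y": "yf"})
    # returns ("xq", "yf")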