From d7e838467aa467c1203b27c2c4d7d41e02b5ef39 Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Tue, 17 Nov 2020 15:00:10 -0800
Subject: [PATCH] [quant][graphmode][fx] Embedding/EmbeddingBag works in static quant qconfig (#48062)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/48062

When Embedding/EmbeddingBag is configured with a static quant qconfig, we skip inserting observers for it in the graph, keep the op unchanged, and print a warning. This aligns with eager mode behavior.

We'll enforce the same behavior for other ops that only support dynamic/weight_only quant but not static quant.

We use a global variable `DEFAULT_NOT_OBSERVED_QUANTIZE_HANDLER`; it is not exposed to users right now, but we can expose it later if needed.

Test Plan: Imported from OSS

Reviewed By: supriyar

Differential Revision: D25007537

fbshipit-source-id: 6ab9e025269b44bbfd0d6dd5bb9f95fe3ca9dead
---
 test/quantization/test_quantize_fx.py         | 26 ++++++++++---------
 torch/quantization/fx/pattern_utils.py        | 12 +++++++++
 .../quantization/fx/quantization_patterns.py  | 12 +++++++--
 torch/quantization/fx/quantize.py             |  8 +++---
 4 files changed, 41 insertions(+), 17 deletions(-)

diff --git a/test/quantization/test_quantize_fx.py b/test/quantization/test_quantize_fx.py
index e3ad5ddde680..3653bc9757f9 100644
--- a/test/quantization/test_quantize_fx.py
+++ b/test/quantization/test_quantize_fx.py
@@ -1948,7 +1948,8 @@ def forward(self, indices):
         quantized_node = ns.call_module(nnq.Embedding)
         configs = [
             (float_qparams_dynamic_qconfig, ns.call_module(nnq.Embedding)),
-            (None, ns.call_module(nn.Embedding))
+            (None, ns.call_module(nn.Embedding)),
+            (default_qconfig, ns.call_module(nn.Embedding)),
         ]

         for qconfig, node in configs:
@@ -1991,17 +1992,18 @@ def forward(self, indices, offsets):
             custom_qconfig=float_qparams_qconfig
         )

-        # check it works in None qconfig
-        qconfig_dict = {"": None}
-        m = M().eval()
-        m = prepare_fx(model, qconfig_dict)
-        self.checkGraphModuleNodes(m, expected_node_occurrence={
-            ns.call_module(torch.quantization.MinMaxObserver): 0
-        })
-        m = convert_fx(m)
-        self.checkGraphModuleNodes(m, expected_node=ns.call_module(nn.EmbeddingBag))
-        # make sure it runs
-        m(*inputs)
+        # check it works with both a None qconfig and a static qconfig
+        for qconfig in [None, default_qconfig]:
+            qconfig_dict = {"": qconfig}
+            m = M().eval()
+            m = prepare_fx(m, qconfig_dict)
+            self.checkGraphModuleNodes(m, expected_node_occurrence={
+                ns.call_module(torch.quantization.MinMaxObserver): 0
+            })
+            m = convert_fx(m)
+            self.checkGraphModuleNodes(m, expected_node=ns.call_module(nn.EmbeddingBag))
+            # make sure it runs
+            m(*inputs)

 class TestQuantizeFxModels(QuantizationTestCase):
     def _test_model_impl(
diff --git a/torch/quantization/fx/pattern_utils.py b/torch/quantization/fx/pattern_utils.py
index ccbd7cc2a2c4..753041e7b08e 100644
--- a/torch/quantization/fx/pattern_utils.py
+++ b/torch/quantization/fx/pattern_utils.py
@@ -37,6 +37,18 @@ def get_default_quant_patterns():
 def get_default_output_activation_post_process_map():
     return DEFAULT_OUTPUT_ACTIVATION_POST_PROCESS_MAP

+# A set of QuantizeHandler classes whose inputs and outputs are not observed:
+# we skip inserting observers for the input and output of these QuantizeHandlers.
+# Used for ops that only support dynamic/weight-only quantization.
+DEFAULT_NOT_OBSERVED_QUANTIZE_HANDLER = set()
+def mark_input_output_not_observed():
+    def insert(fn):
+        DEFAULT_NOT_OBSERVED_QUANTIZE_HANDLER.add(fn)
+        return fn
+    return insert
+
+def input_output_observed(qh):
+    return type(qh) not in DEFAULT_NOT_OBSERVED_QUANTIZE_HANDLER
# Example use of register pattern function:
# @register_fusion_pattern(torch.nn.ReLU, (torch.nn.BatchNorm2d, torch.nn.Conv2d)))
diff --git a/torch/quantization/fx/quantization_patterns.py b/torch/quantization/fx/quantization_patterns.py
index 98263e80d3fb..62eba0069e39 100644
--- a/torch/quantization/fx/quantization_patterns.py
+++ b/torch/quantization/fx/quantization_patterns.py
@@ -15,6 +15,7 @@
 )
 from .pattern_utils import (
     register_quant_pattern,
+    mark_input_output_not_observed,
 )
 from .utils import (
     _parent_name,
@@ -30,6 +31,7 @@

 from abc import ABC, abstractmethod
 import operator
+import warnings

 # -------------------------
 # Pattern Registrations
@@ -418,6 +420,7 @@ def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_

 @register_quant_pattern(torch.nn.Embedding)
 @register_quant_pattern(torch.nn.EmbeddingBag)
+@mark_input_output_not_observed()
 class Embedding(QuantizeHandler):
     def __init__(self, quantizer, node):
         super().__init__(quantizer, node)
@@ -437,8 +440,13 @@ def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_
         emb = quantizer.modules[emb_node.target]
         qconfig = quantizer.qconfig_map[node.name]
         dtypes = get_qconfig_dtypes(qconfig)
-        assert dtypes in supported_dtypes, "qconfig dtype pair not supported:" \
-            " {}, supported dtypes are: {}".format(dtypes, supported_dtypes)
+        if dtypes not in supported_dtypes:
+            warnings.warn(
+                "dtype combination: {} is not "
+                "supported by Embedding/EmbeddingBag, "
+                "supported dtype combinations are: {}".format(dtypes, supported_dtypes))
+            return quantizer.quantized_graph.node_copy(node, load_arg(quantized=None))
+
         qemb = get_static_quant_module_class(type(emb))
         quantized = qemb.from_float(emb)
         parent_name, name = _parent_name(emb_node.target)
diff --git a/torch/quantization/fx/quantize.py b/torch/quantization/fx/quantize.py
index 0777c9bf8445..ea802b904db3 100644
--- a/torch/quantization/fx/quantize.py
+++ b/torch/quantization/fx/quantize.py
@@ -32,6 +32,7 @@
     is_match,
     get_default_quant_patterns,
     get_default_output_activation_post_process_map,
+    input_output_observed,
 )

 from .observed_module import (
@@ -479,7 +480,7 @@ def input_is_observed(arg):
                     output_is_observed = self.modules[node.target]._output_is_observed
                     if output_is_observed:
                         observed_node_names_set.add(node.name)
-                elif quantize_handler.all_node_args:
+                elif quantize_handler.all_node_args and input_output_observed(quantize_handler):
                     # observer for outputs
                     new_observer = qconfig.activation()
                     insert_observer(node, new_observer)
@@ -710,7 +711,8 @@ def is_output_quantized(node):
                 'CopyNode of type ' + node.op + ' is not handled'
             quantized = is_quantized(node.args[0])

-            if not activation_is_statically_quantized(qconfig):
+            if not activation_is_statically_quantized(qconfig) or \
+                    not input_output_observed(obj):
                 quantized = False

         return quantized
@@ -975,7 +977,7 @@ def visit_arg(arg):
                 # don't attach observer/fake_quant for CopyNode
                 if isinstance(quantize_handler, CopyNode):
                     qconfig = None
-                if root_node is node:
+                if root_node is node and input_output_observed(quantize_handler):
                     # matched_nodes[-1] is the first op in the sequence and
                     # matched_nodes[0] is the last op in the sequence
                     # inputs
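
Editor's note: below is a minimal sketch (not part of the patch) of the user-visible behavior this change introduces. The module and names (EmbModel, indices) are illustrative, and it assumes the FX graph mode quantization APIs of this era: prepare_fx/convert_fx from torch.quantization.quantize_fx, plus default_qconfig and float_qparams_dynamic_qconfig from torch.quantization.

    import torch
    import torch.nn as nn
    from torch.quantization import default_qconfig, float_qparams_dynamic_qconfig
    from torch.quantization.quantize_fx import prepare_fx, convert_fx

    class EmbModel(nn.Module):  # hypothetical example module
        def __init__(self):
            super().__init__()
            self.emb = nn.Embedding(num_embeddings=10, embedding_dim=4)

        def forward(self, indices):
            return self.emb(indices)

    indices = torch.tensor([0, 1, 2, 3])

    # Static qconfig: with this patch, no observers are inserted for the
    # embedding; convert warns about the unsupported dtype combination and
    # leaves the float nn.Embedding in place instead of raising an assert.
    m = prepare_fx(EmbModel().eval(), {"": default_qconfig})
    m = convert_fx(m)
    assert isinstance(m.emb, nn.Embedding)
    m(indices)  # still runs as a float op

    # Weight-only qconfig: the supported path; convert swaps the module
    # for the quantized embedding.
    m = prepare_fx(EmbModel().eval(), {"": float_qparams_dynamic_qconfig})
    m = convert_fx(m)
    print(type(m.emb))  # expected: the quantized embedding (nnq.Embedding)
    m(indices)

The design choice mirrors eager mode: an unsupported qconfig on these ops degrades gracefully to a warning plus a no-op copy of the node, rather than failing the whole conversion.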