From ad64e8435747e61f1a0540fde9d346eb70e3b6f3 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Wed, 9 Dec 2020 15:29:48 -0800 Subject: [PATCH 1/3] [quant][graphmode][fx] Add support for dynamic quant for RNN and RNNCell Summary: Test Plan: python test/test_quantization.py TestQuantizeFxOps.test_rnn python test/test_quantization.py TestQuantizeFxOps.test_rnn_cell Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned] --- test/quantization/test_quantize_fx.py | 46 +++++++++++++++++++ .../quantization/fx/quantization_patterns.py | 45 +++++++++++++++++- torch/quantization/quantization_mappings.py | 13 ++++++ 3 files changed, 103 insertions(+), 1 deletion(-) diff --git a/test/quantization/test_quantize_fx.py b/test/quantization/test_quantize_fx.py index f5f243a1e649..e01fa057f9c8 100644 --- a/test/quantization/test_quantize_fx.py +++ b/test/quantization/test_quantize_fx.py @@ -28,6 +28,7 @@ default_qconfig, default_dynamic_qconfig, default_qat_qconfig, + per_channel_dynamic_qconfig, float16_dynamic_qconfig, float_qparams_weight_only_qconfig, get_default_qconfig, @@ -36,6 +37,7 @@ prepare, prepare_qat, convert, + quantize_dynamic, default_placeholder_observer, PerChannelMinMaxObserver, QConfigDynamic, @@ -57,6 +59,8 @@ from torch.testing._internal.common_quantization import ( LinearModelWithSubmodule, ResNetBase, + RNNDynamicModel, + RNNCellDynamicModel, ) from torch.testing._internal.common_quantized import ( @@ -2107,6 +2111,48 @@ def forward(self, indices, offsets): # make sure it runs m(*inputs) + def _test_rnn_impl(self, qconfigs, M, module_type_strs, module_types, sample_input): + options = itertools.product(qconfigs, module_type_strs) + for qconfig, module_type_str in options: + model_eager = M(module_type_str).eval() + model_graph = copy.deepcopy(model_eager) + if torch.backends.quantized.engine == 'qnnpack' and \ + qconfig is float16_dynamic_qconfig: + continue + # fp16 dynamic quant is not supported for qnnpack + + eager_qconfig_dict = {x : qconfig for x in module_types} + model_eager = quantize_dynamic(model_eager, qconfig_spec=eager_qconfig_dict) + + graph_qconfig_dict = { + "object_type": [ + (x, qconfig) for x in module_types + ] + } + model_graph = prepare_fx(model_graph, graph_qconfig_dict) + model_graph = convert_fx(model_graph) + self.assertEqual(model_eager(sample_input), model_graph(sample_input)) + + def test_rnn_cell(self): + qconfigs = [per_channel_dynamic_qconfig, default_dynamic_qconfig, float16_dynamic_qconfig] + module_type_strs = ['LSTMCell', 'GRUCell', 'RNNTanh', 'RNNReLU'] + module_types = [torch.nn.LSTMCell, torch.nn.GRUCell, torch.nn.RNNCell] + sample_input = torch.tensor([[100, -155], + [-155, 100], + [100, -155]], dtype=torch.float) + self._test_rnn_impl(qconfigs, RNNCellDynamicModel, module_type_strs, module_types, sample_input) + + def test_rnn(self): + qconfigs = [per_channel_dynamic_qconfig, default_dynamic_qconfig, float16_dynamic_qconfig] + module_type_strs = ['LSTM'] + module_types = [torch.nn.LSTM] + niter = 10 + sample_input = torch.tensor([[100, -155], + [-155, 100], + [100, -155]], dtype=torch.float).unsqueeze(0).repeat(niter, 1, 1) + self._test_rnn_impl(qconfigs, RNNDynamicModel, module_type_strs, module_types, sample_input) + + class TestQuantizeFxModels(QuantizationTestCase): def _test_model_impl( self, mode, name, model, eager_quantizable_model, diff --git a/torch/quantization/fx/quantization_patterns.py b/torch/quantization/fx/quantization_patterns.py index 176cd7603286..a000d513a126 100644 --- a/torch/quantization/fx/quantization_patterns.py 
+++ b/torch/quantization/fx/quantization_patterns.py
@@ -11,6 +11,7 @@ from ..quantization_mappings import (
     get_static_quant_module_class,
+    get_dynamic_quant_module_class,
     get_quantized_operator,
 )
 from ..utils import (
@@ -471,7 +472,6 @@ def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
         ]
         assert node.op == 'call_module'
         emb_node = node
-        emb = quantizer.modules[emb_node.target]
         qconfig = quantizer.qconfig_map[node.name]
         dtypes = get_qconfig_dtypes(qconfig)
         if dtypes not in supported_dtypes:
@@ -481,6 +481,7 @@ def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
                 "supported dtype combinations are: {}".format(dtypes, supported_dtypes))
             return quantizer.quantized_graph.node_copy(node, load_arg(quantized=None))
 
+        emb = quantizer.modules[emb_node.target]
         qemb = get_static_quant_module_class(type(emb))
         quantized = qemb.from_float(emb)
         parent_name, name = _parent_name(emb_node.target)
@@ -491,6 +492,48 @@ def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
             load_arg(quantized=False)(emb_node.args),
             load_arg(quantized=False)(emb_node.kwargs))
 
+# TODO: merge with embedding quantize handler
+@register_quant_pattern(torch.nn.GRUCell)
+@register_quant_pattern(torch.nn.LSTMCell)
+@register_quant_pattern(torch.nn.RNNCell)
+@register_quant_pattern(torch.nn.LSTM)
+@mark_input_output_not_observed()
+class RNNDynamic(QuantizeHandler):
+    def __init__(self, quantizer: QuantizerCls, node: Node):
+        super().__init__(quantizer, node)
+
+    def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
+                debug: bool = False,
+                convert_custom_config_dict: Dict[str, Any] = None) -> Node:
+        # Supported combinations are:
+        # quant_type | activation | weight | activation_compute_type
+        # dynamic | float32 | qint8 | quint8
+        # dynamic | float16 | float16 | None
+        # tuple (activation_dtype, weight_dtype, compute_dtype)
+        supported_dtypes = [
+            (torch.float32, torch.qint8, torch.quint8),
+            (torch.float16, torch.float16, None),
+        ]
+        assert node.op == 'call_module'
+        qconfig = quantizer.qconfig_map[node.name]
+        dtypes = get_qconfig_dtypes(qconfig)
+        if dtypes not in supported_dtypes:
+            warnings.warn(
+                "dtype combination: {} is not "
+                "supported by dynamically quantized RNN/RNNCell modules, "
+                "supported dtype combinations are: {}".format(dtypes, supported_dtypes))
+            return quantizer.quantized_graph.node_copy(node, load_arg(quantized=None))
+
+        module = quantizer.modules[node.target]
+        qmodule = get_dynamic_quant_module_class(type(module))
+        quantized = qmodule.from_float(module)
+        parent_name, name = _parent_name(node.target)
+        setattr(quantizer.modules[parent_name], name, quantized)
+        return quantizer.quantized_graph.create_node(
+            'call_module',
+            node.target,
+            load_arg(quantized=False)(node.args),
+            load_arg(quantized=False)(node.kwargs))
 
 ARGS_TO_SKIP = {
     torch._ops.ops.quantized.hardswish: ['inplace'],
diff --git a/torch/quantization/quantization_mappings.py b/torch/quantization/quantization_mappings.py
index 88d264b1ccf3..c965de07deb7 100644
--- a/torch/quantization/quantization_mappings.py
+++ b/torch/quantization/quantization_mappings.py
@@ -124,6 +124,19 @@ def get_static_quant_module_class(float_module_class, additional_static_quant_ma
         " does not have a corresponding quantized module class"
     return static_quant_module_class
 
+def get_dynamic_quant_module_class(float_module_class, additional_dynamic_quant_mapping=None):
+    r"""Get the dynamically quantized module class corresponding to
+    the floating point module class
+    """
+    if 
additional_dynamic_quant_mapping is None: + additional_dynamic_quant_mapping = {} + all_mappings = get_combined_dict(DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS, additional_dynamic_quant_mapping) + dynamic_quant_module_class = all_mappings.get(float_module_class, None) + assert dynamic_quant_module_class is not None, \ + "Floating point module class {}".format(str(float_module_class)) + \ + " does not have a corresponding quantized module class" + return dynamic_quant_module_class + def get_default_qat_module_mappings(): ''' Get default module mapping for quantization aware training ''' From 2dbfead4e771032f7be711c99cf81227d6f1ef81 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Wed, 9 Dec 2020 15:38:00 -0800 Subject: [PATCH 2/3] Update on "[quant][graphmode][fx] Add support for dynamic quant for RNN and RNNCell" Summary: Test Plan: python test/test_quantization.py TestQuantizeFxOps.test_rnn python test/test_quantization.py TestQuantizeFxOps.test_rnn_cell Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned] --- torch/quantization/fx/quantization_patterns.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/quantization/fx/quantization_patterns.py b/torch/quantization/fx/quantization_patterns.py index a000d513a126..50ef6c33cc62 100644 --- a/torch/quantization/fx/quantization_patterns.py +++ b/torch/quantization/fx/quantization_patterns.py @@ -492,7 +492,7 @@ def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable, load_arg(quantized=False)(emb_node.args), load_arg(quantized=False)(emb_node.kwargs)) -# TODO: merge with embedding quantize handler +# TODO (maybe): merge with embedding quantize handler @register_quant_pattern(torch.nn.GRUCell) @register_quant_pattern(torch.nn.LSTMCell) @register_quant_pattern(torch.nn.RNNCell) From ff4353e303e447fb46f2f33755a0590f6760c703 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Thu, 10 Dec 2020 11:21:15 -0800 Subject: [PATCH 3/3] Update on "[quant][graphmode][fx] Add support for dynamic quant for RNN and RNNCell" Summary: Test Plan: python test/test_quantization.py TestQuantizeFxOps.test_rnn python test/test_quantization.py TestQuantizeFxOps.test_rnn_cell Reviewers: Subscribers: Tasks: Tags: Differential Revision: [D25449047](https://our.internmc.facebook.com/intern/diff/D25449047) [ghstack-poisoned] --- test/quantization/test_quantize_fx.py | 1 + torch/quantization/fx/quantization_patterns.py | 6 +++--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/test/quantization/test_quantize_fx.py b/test/quantization/test_quantize_fx.py index e01fa057f9c8..7c6c548f2594 100644 --- a/test/quantization/test_quantize_fx.py +++ b/test/quantization/test_quantize_fx.py @@ -2132,6 +2132,7 @@ def _test_rnn_impl(self, qconfigs, M, module_type_strs, module_types, sample_inp model_graph = prepare_fx(model_graph, graph_qconfig_dict) model_graph = convert_fx(model_graph) self.assertEqual(model_eager(sample_input), model_graph(sample_input)) + self.checkScriptable(model_graph, [[sample_input]], True) def test_rnn_cell(self): qconfigs = [per_channel_dynamic_qconfig, default_dynamic_qconfig, float16_dynamic_qconfig] diff --git a/torch/quantization/fx/quantization_patterns.py b/torch/quantization/fx/quantization_patterns.py index 50ef6c33cc62..73590ad60904 100644 --- a/torch/quantization/fx/quantization_patterns.py +++ b/torch/quantization/fx/quantization_patterns.py @@ -525,10 +525,10 @@ def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable, return quantizer.quantized_graph.node_copy(node, 
load_arg(quantized=None)) module = quantizer.modules[node.target] - qmodule = get_dynamic_quant_module_class(type(module)) - quantized = qmodule.from_float(module) + qmodule_cls = get_dynamic_quant_module_class(type(module)) + qmodule = qmodule_cls.from_float(module) parent_name, name = _parent_name(node.target) - setattr(quantizer.modules[parent_name], name, quantized) + setattr(quantizer.modules[parent_name], name, qmodule) return quantizer.quantized_graph.create_node( 'call_module', node.target,
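Taken together, the three patches let FX graph mode quantization dynamically quantize RNN modules the same way the eager path does. Below is a minimal usage sketch mirroring the new _test_rnn_impl test above: an nn.LSTM is dynamically quantized through the FX prepare_fx/convert_fx path and compared against the eager quantize_dynamic path. The wrapper module, layer sizes, and input shape are illustrative only (not part of the patch), and the APIs are used as they exist at the time of this change (prepare_fx without an example_inputs argument, torch.quantization rather than torch.ao.quantization).

    import copy

    import torch
    from torch.quantization import default_dynamic_qconfig, quantize_dynamic
    from torch.quantization.quantize_fx import convert_fx, prepare_fx

    class LSTMWrapper(torch.nn.Module):
        """Illustrative float model wrapping a single nn.LSTM layer."""
        def __init__(self):
            super().__init__()
            self.lstm = torch.nn.LSTM(input_size=2, hidden_size=2)

        def forward(self, x):
            return self.lstm(x)

    float_model = LSTMWrapper().eval()
    sample_input = torch.randn(10, 1, 2)  # (seq_len, batch, input_size)

    # Eager mode reference: qconfig_spec keyed by module type, as in eager_qconfig_dict.
    eager_quantized = quantize_dynamic(
        copy.deepcopy(float_model),
        qconfig_spec={torch.nn.LSTM: default_dynamic_qconfig})

    # FX graph mode: the "object_type" qconfig entry exercised by the new RNNDynamic handler.
    graph_qconfig_dict = {"object_type": [(torch.nn.LSTM, default_dynamic_qconfig)]}
    graph_quantized = convert_fx(prepare_fx(copy.deepcopy(float_model), graph_qconfig_dict))

    out_eager, _ = eager_quantized(sample_input)
    out_graph, _ = graph_quantized(sample_input)
    assert torch.allclose(out_eager, out_graph)  # both paths swap in the same dynamic LSTM

In both cases the float nn.LSTM is replaced by torch.nn.quantized.dynamic.LSTM; on the FX side the swap goes through the new get_dynamic_quant_module_class lookup into DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS.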