diff --git a/test/quantization/test_quantize_fx.py b/test/quantization/test_quantize_fx.py
index f5f243a1e649..7c6c548f2594 100644
--- a/test/quantization/test_quantize_fx.py
+++ b/test/quantization/test_quantize_fx.py
@@ -28,6 +28,7 @@
     default_qconfig,
     default_dynamic_qconfig,
     default_qat_qconfig,
+    per_channel_dynamic_qconfig,
     float16_dynamic_qconfig,
     float_qparams_weight_only_qconfig,
     get_default_qconfig,
@@ -36,6 +37,7 @@
     prepare,
     prepare_qat,
     convert,
+    quantize_dynamic,
     default_placeholder_observer,
     PerChannelMinMaxObserver,
     QConfigDynamic,
@@ -57,6 +59,8 @@ from torch.testing._internal.common_quantization import (
     LinearModelWithSubmodule,
     ResNetBase,
+    RNNDynamicModel,
+    RNNCellDynamicModel,
 )
 
 from torch.testing._internal.common_quantized import (
@@ -2107,6 +2111,49 @@ def forward(self, indices, offsets):
         # make sure it runs
         m(*inputs)
 
+    def _test_rnn_impl(self, qconfigs, M, module_type_strs, module_types, sample_input):
+        options = itertools.product(qconfigs, module_type_strs)
+        for qconfig, module_type_str in options:
+            model_eager = M(module_type_str).eval()
+            model_graph = copy.deepcopy(model_eager)
+            # fp16 dynamic quant is not supported for qnnpack
+            if torch.backends.quantized.engine == 'qnnpack' and \
+                    qconfig is float16_dynamic_qconfig:
+                continue
+
+            eager_qconfig_dict = {x : qconfig for x in module_types}
+            model_eager = quantize_dynamic(model_eager, qconfig_spec=eager_qconfig_dict)
+
+            graph_qconfig_dict = {
+                "object_type": [
+                    (x, qconfig) for x in module_types
+                ]
+            }
+            model_graph = prepare_fx(model_graph, graph_qconfig_dict)
+            model_graph = convert_fx(model_graph)
+            self.assertEqual(model_eager(sample_input), model_graph(sample_input))
+            self.checkScriptable(model_graph, [[sample_input]], True)
+
+    def test_rnn_cell(self):
+        qconfigs = [per_channel_dynamic_qconfig, default_dynamic_qconfig, float16_dynamic_qconfig]
+        module_type_strs = ['LSTMCell', 'GRUCell', 'RNNTanh', 'RNNReLU']
+        module_types = [torch.nn.LSTMCell, torch.nn.GRUCell, torch.nn.RNNCell]
+        sample_input = torch.tensor([[100, -155],
+                                     [-155, 100],
+                                     [100, -155]], dtype=torch.float)
+        self._test_rnn_impl(qconfigs, RNNCellDynamicModel, module_type_strs, module_types, sample_input)
+
+    def test_rnn(self):
+        qconfigs = [per_channel_dynamic_qconfig, default_dynamic_qconfig, float16_dynamic_qconfig]
+        module_type_strs = ['LSTM']
+        module_types = [torch.nn.LSTM]
+        niter = 10
+        sample_input = torch.tensor([[100, -155],
+                                     [-155, 100],
+                                     [100, -155]], dtype=torch.float).unsqueeze(0).repeat(niter, 1, 1)
+        self._test_rnn_impl(qconfigs, RNNDynamicModel, module_type_strs, module_types, sample_input)
+
+
 class TestQuantizeFxModels(QuantizationTestCase):
     def _test_model_impl(
             self, mode, name, model, eager_quantizable_model,
diff --git a/torch/quantization/fx/quantization_patterns.py b/torch/quantization/fx/quantization_patterns.py
index 176cd7603286..73590ad60904 100644
--- a/torch/quantization/fx/quantization_patterns.py
+++ b/torch/quantization/fx/quantization_patterns.py
@@ -11,6 +11,7 @@
 from ..quantization_mappings import (
     get_static_quant_module_class,
+    get_dynamic_quant_module_class,
     get_quantized_operator,
 )
 from ..utils import (
@@ -471,7 +472,6 @@ def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
         ]
         assert node.op == 'call_module'
         emb_node = node
-        emb = quantizer.modules[emb_node.target]
         qconfig = quantizer.qconfig_map[node.name]
         dtypes = get_qconfig_dtypes(qconfig)
         if dtypes not in supported_dtypes:
@@ -481,6 +481,7 @@ def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
                 "supported dtype combinations are: {}".format(dtypes, supported_dtypes))
             return quantizer.quantized_graph.node_copy(node, load_arg(quantized=None))
+        emb = quantizer.modules[emb_node.target]
         qemb = get_static_quant_module_class(type(emb))
         quantized = qemb.from_float(emb)
         parent_name, name = _parent_name(emb_node.target)
@@ -491,6 +492,48 @@ def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
             load_arg(quantized=False)(emb_node.args),
             load_arg(quantized=False)(emb_node.kwargs))
 
+# TODO (maybe): merge with embedding quantize handler
+@register_quant_pattern(torch.nn.GRUCell)
+@register_quant_pattern(torch.nn.LSTMCell)
+@register_quant_pattern(torch.nn.RNNCell)
+@register_quant_pattern(torch.nn.LSTM)
+@mark_input_output_not_observed()
+class RNNDynamic(QuantizeHandler):
+    def __init__(self, quantizer: QuantizerCls, node: Node):
+        super().__init__(quantizer, node)
+
+    def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
+                debug: bool = False,
+                convert_custom_config_dict: Dict[str, Any] = None) -> Node:
+        # Supported combinations are:
+        # quant_type | activation | weight  | activation_compute_type
+        # dynamic    | float32    | qint8   | quint8
+        # dynamic    | float16    | float16 | None
+        # tuple (activation_dtype, weight_dtype, compute_dtype)
+        supported_dtypes = [
+            (torch.float32, torch.qint8, torch.quint8),
+            (torch.float16, torch.float16, None),
+        ]
+        assert node.op == 'call_module'
+        qconfig = quantizer.qconfig_map[node.name]
+        dtypes = get_qconfig_dtypes(qconfig)
+        if dtypes not in supported_dtypes:
+            warnings.warn(
+                "dtype combination: {} is not "
+                "supported by dynamically quantized RNN modules, "
+                "supported dtype combinations are: {}".format(dtypes, supported_dtypes))
+            return quantizer.quantized_graph.node_copy(node, load_arg(quantized=None))
+
+        module = quantizer.modules[node.target]
+        qmodule_cls = get_dynamic_quant_module_class(type(module))
+        qmodule = qmodule_cls.from_float(module)
+        parent_name, name = _parent_name(node.target)
+        setattr(quantizer.modules[parent_name], name, qmodule)
+        return quantizer.quantized_graph.create_node(
+            'call_module',
+            node.target,
+            load_arg(quantized=False)(node.args),
+            load_arg(quantized=False)(node.kwargs))
 
 ARGS_TO_SKIP = {
     torch._ops.ops.quantized.hardswish: ['inplace'],
diff --git a/torch/quantization/quantization_mappings.py b/torch/quantization/quantization_mappings.py
index 88d264b1ccf3..c965de07deb7 100644
--- a/torch/quantization/quantization_mappings.py
+++ b/torch/quantization/quantization_mappings.py
@@ -124,6 +124,19 @@ def get_static_quant_module_class(float_module_class, additional_static_quant_ma
         " does not have a corresponding quantized module class"
     return static_quant_module_class
 
+def get_dynamic_quant_module_class(float_module_class, additional_dynamic_quant_mapping=None):
+    r"""Get the dynamically quantized module class corresponding to
+    the floating point module class
+    """
+    if additional_dynamic_quant_mapping is None:
+        additional_dynamic_quant_mapping = {}
+    all_mappings = get_combined_dict(DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS, additional_dynamic_quant_mapping)
+    dynamic_quant_module_class = all_mappings.get(float_module_class, None)
+    assert dynamic_quant_module_class is not None, \
+        "Floating point module class {}".format(str(float_module_class)) + \
+        " does not have a corresponding dynamically quantized module class"
+    return dynamic_quant_module_class
+
 def get_default_qat_module_mappings():
     ''' Get default module mapping for quantization aware training
     '''
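
Note (not part of the patch): a minimal usage sketch of the FX graph mode dynamic
quantization path this change enables for RNN modules, mirroring the
graph_qconfig_dict used in the new test above. It assumes the prepare_fx/convert_fx
API as exercised in that test; the wrapper module, layer sizes, and sample input
below are hypothetical.

    import torch
    from torch.quantization import per_channel_dynamic_qconfig
    from torch.quantization.quantize_fx import prepare_fx, convert_fx

    class LSTMWrapper(torch.nn.Module):  # hypothetical module for illustration
        def __init__(self):
            super().__init__()
            self.lstm = torch.nn.LSTM(2, 2)

        def forward(self, x):
            out, _ = self.lstm(x)
            return out

    model = LSTMWrapper().eval()
    # qconfig keyed by object type, as in the test's graph_qconfig_dict
    qconfig_dict = {"object_type": [(torch.nn.LSTM, per_channel_dynamic_qconfig)]}
    prepared = prepare_fx(model, qconfig_dict)
    quantized = convert_fx(prepared)  # nn.LSTM is swapped for its dynamic counterpart
    out = quantized(torch.randn(10, 1, 2))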