From 576fa09157de3e83016cf81cf23c6334f9a69ef1 Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Wed, 18 Nov 2020 18:19:02 -0800
Subject: [PATCH] [quant][fix] Fix quant type classification for float_qparam qconfig (#48069)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/48069

Also renames float_qparams_dynamic_qconfig to float_qparams_weight_only_qconfig.
The qconfig does weight-only quantization (activations stay in float32), so its
activation observer is switched from default_dynamic_quant_observer to the new
default_placeholder_observer and the qconfig is no longer classified as dynamic
quantization. It is not used in user code yet, so only the tests need to be
updated.

Test Plan: Imported from OSS

Reviewed By: supriyar

Differential Revision: D25010175

fbshipit-source-id: caa3eaa5358a8bc5c808bf5f64e6ebff3e0b61e8
---
 test/quantization/test_quantize.py             |  4 ++--
 test/quantization/test_quantize_fx.py          |  8 ++++----
 torch/nn/quantized/modules/embedding_ops.py    |  8 ++++----
 torch/quantization/__init__.py                 |  3 ++-
 torch/quantization/fx/quantization_patterns.py | 10 +++++-----
 torch/quantization/observer.py                 |  3 ++-
 torch/quantization/qconfig.py                  |  7 +++++--
 torch/quantization/quantize.py                 |  6 +++---
 torch/testing/_internal/common_quantization.py |  4 ++--
 9 files changed, 29 insertions(+), 24 deletions(-)

diff --git a/test/quantization/test_quantize.py b/test/quantization/test_quantize.py
index 96b4546f5fad..ee4a114dcee0 100644
--- a/test/quantization/test_quantize.py
+++ b/test/quantization/test_quantize.py
@@ -22,7 +22,7 @@
     default_dynamic_qconfig,
     per_channel_dynamic_qconfig,
     float16_dynamic_qconfig,
-    float_qparams_dynamic_qconfig,
+    float_qparams_weight_only_qconfig,
     PerChannelMinMaxObserver,
     QConfigDynamic,
     default_dynamic_quant_observer,
@@ -521,7 +521,7 @@ def test_quantized_embedding(self):
         model = EmbeddingModule().eval()
         indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8, 3, 2, 3, 6, 3, 6, 5, 7, 0, 8, 4, 6, 5, 8, 2, 3])
         weights = torch.randn(10, 12, dtype=torch.float32)
-        model.qconfig = float_qparams_dynamic_qconfig
+        model.qconfig = float_qparams_weight_only_qconfig
         prepare(model, inplace=True)
         convert(model, inplace=True)
         self.assertTrue('QuantizedEmbedding' in str(model))
diff --git a/test/quantization/test_quantize_fx.py b/test/quantization/test_quantize_fx.py
index 3653bc9757f9..d4116c2061ef 100644
--- a/test/quantization/test_quantize_fx.py
+++ b/test/quantization/test_quantize_fx.py
@@ -21,16 +21,16 @@
     quant_type_to_str,
     default_qconfig,
     default_dynamic_qconfig,
-    default_dynamic_quant_observer,
     default_qat_qconfig,
     float16_dynamic_qconfig,
-    float_qparams_dynamic_qconfig,
+    float_qparams_weight_only_qconfig,
     get_default_qconfig,
     get_default_qat_qconfig,
     fuse_modules,
     prepare,
     prepare_qat,
     convert,
+    default_placeholder_observer,
     PerChannelMinMaxObserver,
     QConfigDynamic,
     FixedQParamsFakeQuantize,
@@ -1947,7 +1947,7 @@ def forward(self, indices):
         indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8, 3, 2, 3, 6, 3, 6, 5, 7, 0, 8, 4, 6, 5, 8, 2, 3])
         quantized_node = ns.call_module(nnq.Embedding)
         configs = [
-            (float_qparams_dynamic_qconfig, ns.call_module(nnq.Embedding)),
+            (float_qparams_weight_only_qconfig, ns.call_module(nnq.Embedding)),
             (None, ns.call_module(nn.Embedding)),
             (default_qconfig, ns.call_module(nn.Embedding)),
         ]
@@ -1982,7 +1982,7 @@ def forward(self, indices, offsets):
         float_qparams_observer = PerChannelMinMaxObserver.with_args(dtype=dtype,
                                                                     qscheme=torch.per_channel_affine_float_qparams,
                                                                     ch_axis=0)
-        float_qparams_qconfig = QConfigDynamic(activation=default_dynamic_quant_observer,
+        float_qparams_qconfig = QConfigDynamic(activation=default_placeholder_observer,
                                                weight=float_qparams_observer)
         self.checkGraphModeFxOp(
             model,
diff --git a/torch/nn/quantized/modules/embedding_ops.py b/torch/nn/quantized/modules/embedding_ops.py
index 2344543f23f8..d16748b3baf7 100644
--- a/torch/nn/quantized/modules/embedding_ops.py
+++ b/torch/nn/quantized/modules/embedding_ops.py
@@ -144,11 +144,11 @@ def from_float(cls, mod):
         assert type(mod) == nn.Embedding, 'nnq.' + cls.__name__ + '.from_float only works for ' + \
             nn.Embedding.__name__
         assert hasattr(mod, 'qconfig'), 'Embedding input float module must have qconfig defined'
-        from torch.quantization.qconfig import float_qparams_dynamic_qconfig
+        from torch.quantization import float_qparams_weight_only_qconfig
         if mod.qconfig is not None and mod.qconfig.weight is not None:
             weight_observer = mod.qconfig.weight()
         else:
-            weight_observer = float_qparams_dynamic_qconfig.weight()
+            weight_observer = float_qparams_weight_only_qconfig.weight()
 
         dtype = weight_observer.dtype
 
@@ -224,11 +224,11 @@ def from_float(cls, mod):
         assert type(mod) == nn.EmbeddingBag, 'nnq.' + cls.__name__ + '.from_float only works for ' + \
             nn.EmbeddingBag.__name__
         assert hasattr(mod, 'qconfig'), 'EmbeddingBag input float module must have qconfig defined'
-        from torch.quantization.qconfig import float_qparams_dynamic_qconfig
+        from torch.quantization.qconfig import float_qparams_weight_only_qconfig
         if mod.qconfig is not None and mod.qconfig.weight is not None:
             weight_observer = mod.qconfig.weight()
         else:
-            weight_observer = float_qparams_dynamic_qconfig.weight()
+            weight_observer = float_qparams_weight_only_qconfig.weight()
 
         dtype = weight_observer.dtype
 
diff --git a/torch/quantization/__init__.py b/torch/quantization/__init__.py
index 24e929b5fc8e..b2a8e542f93a 100644
--- a/torch/quantization/__init__.py
+++ b/torch/quantization/__init__.py
@@ -43,7 +43,7 @@ def default_eval_fn(model, calib_data):
     'register_activation_post_process_hook',
     # Observers
     'ObserverBase', 'WeightObserver', 'observer', 'default_observer',
-    'default_weight_observer',
+    'default_weight_observer', 'default_placeholder_observer',
     # FakeQuantize (for qat)
     'default_fake_quant', 'default_weight_fake_quant',
     'default_symmetric_fixed_qparams_fake_quant',
@@ -52,6 +52,7 @@ def default_eval_fn(model, calib_data):
     'default_histogram_fake_quant',
     # QConfig
     'QConfig', 'default_qconfig', 'default_dynamic_qconfig', 'float16_dynamic_qconfig',
+    'float_qparams_weight_only_qconfig',
     # QAT utilities
     'default_qat_qconfig', 'prepare_qat', 'quantize_qat',
     # module transformations
diff --git a/torch/quantization/fx/quantization_patterns.py b/torch/quantization/fx/quantization_patterns.py
index ab62c300c8bf..38f520815a5e 100644
--- a/torch/quantization/fx/quantization_patterns.py
+++ b/torch/quantization/fx/quantization_patterns.py
@@ -432,13 +432,13 @@ def __init__(self, quantizer, node):
 
     def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_dict=None):
         # Supported combinations are:
-        # quant_type  | activation (compute_type) | weight
-        # weight_only |  float32 (torch.uint8)    | quint8
-        # weight_only |  float32 (torch.uint8)    | quint4x2
+        # quant_type  | activation | weight   | activation_compute_type
+        # weight_only |  float32   | quint8   | None
+        # weight_only |  float32   | quint4x2 | None
         # tuple (activation_dtype, weight_dtype, compute_dtype)
         supported_dtypes = [
-            (torch.float32, torch.quint8, torch.quint8),
-            (torch.float32, torch.quint4x2, torch.quint8),
+            (torch.float32, torch.quint8, None),
+            (torch.float32, torch.quint4x2, None),
         ]
         assert node.op == 'call_module'
         emb_node = node
diff --git a/torch/quantization/observer.py b/torch/quantization/observer.py
index fbd8168393c8..32d07c939695 100644
--- a/torch/quantization/observer.py
+++ b/torch/quantization/observer.py
@@ -989,7 +989,7 @@ class PlaceholderObserver(ObserverBase):
     custom_op_name: (temporary) specify this observer for an operator that doesn't
                     require any observation (Can be used in Graph Mode Passes for special case ops).
     """
-    def __init__(self, dtype=torch.float16, custom_op_name="", compute_dtype=None):
+    def __init__(self, dtype=torch.float32, custom_op_name="", compute_dtype=None):
         super(PlaceholderObserver, self).__init__(dtype=dtype)
         # dtype of input of the target operator, e.g. for dynamic quantization
         # ops, the dtype will be float32
@@ -1126,6 +1126,7 @@ def load_observer_state_dict(mod, obs_dict):
 
 # Restrict activations to be in the range (0,127)
 default_observer = MinMaxObserver.with_args(reduce_range=True)
+default_placeholder_observer = PlaceholderObserver
 default_debug_observer = RecordingObserver
 default_weight_observer = MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric)
 default_histogram_observer = HistogramObserver.with_args(reduce_range=True)
diff --git a/torch/quantization/qconfig.py b/torch/quantization/qconfig.py
index 253abbaf4445..8da4ad6bb182 100644
--- a/torch/quantization/qconfig.py
+++ b/torch/quantization/qconfig.py
@@ -67,8 +67,11 @@ def __new__(cls, activation=torch.nn.Identity, weight=torch.nn.Identity):
 per_channel_dynamic_qconfig = QConfigDynamic(activation=default_dynamic_quant_observer,
                                              weight=default_per_channel_weight_observer)
 
-float_qparams_dynamic_qconfig = QConfigDynamic(activation=default_dynamic_quant_observer,
-                                               weight=default_float_qparams_observer)
+# TODO: this is weight only quant, change this to QConfigWeightOnly
+# or remove the QConfigDynamic later
+float_qparams_weight_only_qconfig = QConfigDynamic(
+    activation=default_placeholder_observer,
+    weight=default_float_qparams_observer)
 
 default_qat_qconfig = QConfig(activation=default_fake_quant,
                               weight=default_weight_fake_quant)
diff --git a/torch/quantization/quantize.py b/torch/quantization/quantize.py
index 796fa8ce30b2..18fd1bcfe757 100644
--- a/torch/quantization/quantize.py
+++ b/torch/quantization/quantize.py
@@ -17,7 +17,7 @@
 )
 from .stubs import DeQuantStub, QuantWrapper
 
-from .qconfig import default_dynamic_qconfig, float16_dynamic_qconfig, float_qparams_dynamic_qconfig
+from .qconfig import default_dynamic_qconfig, float16_dynamic_qconfig, float_qparams_weight_only_qconfig
 
 def is_activation_post_process(module):
     return (isinstance(module, torch.quantization.ObserverBase) or
@@ -352,7 +352,7 @@ def quantize_dynamic(model, qconfig_spec=None, dtype=torch.qint8,
             }
         elif dtype == torch.quint8:
             qconfig_spec = {
-                nn.EmbeddingBag : float_qparams_dynamic_qconfig,
+                nn.EmbeddingBag : float_qparams_weight_only_qconfig,
             }
         else:
             raise ValueError(
@@ -363,7 +363,7 @@ def quantize_dynamic(model, qconfig_spec=None, dtype=torch.qint8,
     elif dtype is torch.float16:
         default_qconfig = float16_dynamic_qconfig
     elif dtype is torch.quint8:
-        default_qconfig = float_qparams_dynamic_qconfig
+        default_qconfig = float_qparams_weight_only_qconfig
     else:
         raise RuntimeError('Unknown dtype specified for quantize_dynamic: ', str(dtype))
     qconfig_spec = dict(zip(qconfig_spec, itertools.repeat(default_qconfig)))
diff --git a/torch/testing/_internal/common_quantization.py b/torch/testing/_internal/common_quantization.py
index 0654f3249823..2e3cc16b4540 100644
--- a/torch/testing/_internal/common_quantization.py
+++ b/torch/testing/_internal/common_quantization.py
@@ -12,7 +12,7 @@
 from torch.testing._internal.common_utils import TestCase
 from torch.quantization import QuantWrapper, QuantStub, DeQuantStub, \
     default_qconfig, default_dynamic_qconfig, default_per_channel_qconfig, QConfig, default_observer, default_weight_observer, \
-    propagate_qconfig_, convert, get_default_qconfig, quantize_dynamic_jit, quantize_jit, float_qparams_dynamic_qconfig, \
+    propagate_qconfig_, convert, get_default_qconfig, quantize_dynamic_jit, quantize_jit, float_qparams_weight_only_qconfig, \
     get_default_qat_qconfig, PerChannelMinMaxObserver, default_dynamic_quant_observer, QConfigDynamic, QuantType
 from torch.quantization.quantization_mappings import (
     get_default_dynamic_quant_module_mappings,
@@ -1449,7 +1449,7 @@ def __init__(self):
         super().__init__()
         self.emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=12)
         self.fc = torch.nn.Linear(5, 5)
-        self.emb.qconfig = float_qparams_dynamic_qconfig
+        self.emb.qconfig = float_qparams_weight_only_qconfig
         self.qconfig = default_qconfig
 
     def forward(self, indices, linear_in):
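
--
Usage note (editor's sketch, not part of the upstream patch): a minimal example of
how the renamed qconfig is exercised by the eager-mode test above. The
EmbeddingModule wrapper, sizes, and indices here are illustrative only.

    import torch
    import torch.nn as nn
    from torch.quantization import convert, prepare, float_qparams_weight_only_qconfig

    class EmbeddingModule(nn.Module):
        def __init__(self):
            super().__init__()
            self.emb = nn.Embedding(num_embeddings=10, embedding_dim=12)

        def forward(self, indices):
            return self.emb(indices)

    model = EmbeddingModule().eval()
    # Weight-only quantization: activations stay float32, the embedding weight
    # is quantized with per-channel affine float qparams.
    model.qconfig = float_qparams_weight_only_qconfig
    prepare(model, inplace=True)
    convert(model, inplace=True)  # self.emb is replaced by the quantized Embedding module

    indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8])
    out = model(indices)

With the dynamic API, the same qconfig is what quantize_dynamic(model, dtype=torch.quint8)
now assigns to nn.EmbeddingBag by default, as shown in the quantize.py hunk above.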