[quant][fix] Fix quant type classification for float_qparam qconfig #48069

Closed
wants to merge 8 commits
test/quantization/test_quantize.py (4 changes: 2 additions & 2 deletions)
@@ -22,7 +22,7 @@
default_dynamic_qconfig,
per_channel_dynamic_qconfig,
float16_dynamic_qconfig,
float_qparams_dynamic_qconfig,
float_qparams_weight_only_qconfig,
PerChannelMinMaxObserver,
QConfigDynamic,
default_dynamic_quant_observer,
@@ -521,7 +521,7 @@ def test_quantized_embedding(self):
model = EmbeddingModule().eval()
indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8, 3, 2, 3, 6, 3, 6, 5, 7, 0, 8, 4, 6, 5, 8, 2, 3])
weights = torch.randn(10, 12, dtype=torch.float32)
model.qconfig = float_qparams_dynamic_qconfig
model.qconfig = float_qparams_weight_only_qconfig
prepare(model, inplace=True)
convert(model, inplace=True)
self.assertTrue('QuantizedEmbedding' in str(model))
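For reference, a minimal eager-mode sketch of the flow this test exercises, assuming a build that already exposes the renamed `float_qparams_weight_only_qconfig`; the module and indices are illustrative:

```python
import torch
import torch.nn as nn
from torch.quantization import convert, float_qparams_weight_only_qconfig, prepare

class EmbeddingModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = nn.Embedding(num_embeddings=10, embedding_dim=12)

    def forward(self, indices):
        return self.emb(indices)

model = EmbeddingModule().eval()
# Weight-only quantization: activations stay float32, the embedding weight is
# quantized to quint8 with per-row (float) scale/zero_point.
model.qconfig = float_qparams_weight_only_qconfig
prepare(model, inplace=True)
convert(model, inplace=True)

indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2])
out = model(indices)
print(type(model.emb))  # expect torch.nn.quantized.Embedding
```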
test/quantization/test_quantize_fx.py (8 changes: 4 additions & 4 deletions)
@@ -19,13 +19,13 @@
quant_type_to_str,
default_qconfig,
default_dynamic_qconfig,
default_dynamic_quant_observer,
default_qat_qconfig,
float16_dynamic_qconfig,
float_qparams_dynamic_qconfig,
float_qparams_weight_only_qconfig,
prepare,
prepare_qat,
convert,
default_placeholder_observer,
PerChannelMinMaxObserver,
QConfigDynamic,
FixedQParamsFakeQuantize,
@@ -1913,7 +1913,7 @@ def forward(self, indices):
indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8, 3, 2, 3, 6, 3, 6, 5, 7, 0, 8, 4, 6, 5, 8, 2, 3])
quantized_node = ns.call_module(nnq.Embedding)
configs = [
(float_qparams_dynamic_qconfig, ns.call_module(nnq.Embedding)),
(float_qparams_weight_only_qconfig, ns.call_module(nnq.Embedding)),
(None, ns.call_module(nn.Embedding)),
(default_qconfig, ns.call_module(nn.Embedding)),
]
@@ -1948,7 +1948,7 @@ def forward(self, indices, offsets):
float_qparams_observer = PerChannelMinMaxObserver.with_args(dtype=dtype,
qscheme=torch.per_channel_affine_float_qparams,
ch_axis=0)
float_qparams_qconfig = QConfigDynamic(activation=default_dynamic_quant_observer,
float_qparams_qconfig = QConfigDynamic(activation=default_placeholder_observer,
weight=float_qparams_observer)
self.checkGraphModeFxOp(
model,
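A rough FX graph mode counterpart, assuming the `prepare_fx`/`convert_fx` entry points and the dict-style `qconfig_dict` in use at the time of this PR; the toy module is made up:

```python
import torch
import torch.nn as nn
from torch.quantization import float_qparams_weight_only_qconfig
from torch.quantization.quantize_fx import convert_fx, prepare_fx

class EmbeddingOnly(nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = nn.Embedding(num_embeddings=10, embedding_dim=12)

    def forward(self, indices):
        return self.emb(indices)

model = EmbeddingOnly().eval()
qconfig_dict = {"": float_qparams_weight_only_qconfig}
prepared = prepare_fx(model, qconfig_dict)  # no calibration needed: weight-only
quantized = convert_fx(prepared)
print(quantized)  # the graph should now call a quantized Embedding module
```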
torch/nn/quantized/modules/embedding_ops.py (8 changes: 4 additions & 4 deletions)
@@ -144,11 +144,11 @@ def from_float(cls, mod):
assert type(mod) == nn.Embedding, 'nnq.' + cls.__name__ + '.from_float only works for ' + \
nn.Embedding.__name__
assert hasattr(mod, 'qconfig'), 'Embedding input float module must have qconfig defined'
from torch.quantization.qconfig import float_qparams_dynamic_qconfig
from torch.quantization import float_qparams_weight_only_qconfig
if mod.qconfig is not None and mod.qconfig.weight is not None:
weight_observer = mod.qconfig.weight()
else:
weight_observer = float_qparams_dynamic_qconfig.weight()
weight_observer = float_qparams_weight_only_qconfig.weight()

dtype = weight_observer.dtype

@@ -224,11 +224,11 @@ def from_float(cls, mod):
assert type(mod) == nn.EmbeddingBag, 'nnq.' + cls.__name__ + '.from_float only works for ' + \
nn.EmbeddingBag.__name__
assert hasattr(mod, 'qconfig'), 'EmbeddingBag input float module must have qconfig defined'
from torch.quantization.qconfig import float_qparams_dynamic_qconfig
from torch.quantization.qconfig import float_qparams_weight_only_qconfig
if mod.qconfig is not None and mod.qconfig.weight is not None:
weight_observer = mod.qconfig.weight()
else:
weight_observer = float_qparams_dynamic_qconfig.weight()
weight_observer = float_qparams_weight_only_qconfig.weight()

dtype = weight_observer.dtype

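The fallback above means `from_float` still works when the float module's qconfig has no usable weight observer; a small sketch of the direct conversion path, assuming quint8 (the default dtype of the `float_qparams_weight_only_qconfig` weight observer):

```python
import torch.nn as nn
import torch.nn.quantized as nnq
from torch.quantization import float_qparams_weight_only_qconfig

float_emb = nn.Embedding(num_embeddings=10, embedding_dim=12)
float_emb.qconfig = float_qparams_weight_only_qconfig

# from_float() instantiates the weight observer from the qconfig (or falls back
# to float_qparams_weight_only_qconfig.weight()) and quantizes the weight.
q_emb = nnq.Embedding.from_float(float_emb)
print(q_emb.weight().dtype)  # expect torch.quint8
```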
torch/quantization/__init__.py (1 change: 1 addition & 0 deletions)
@@ -52,6 +52,7 @@ def default_eval_fn(model, calib_data):
'default_histogram_fake_quant',
# QConfig
'QConfig', 'default_qconfig', 'default_dynamic_qconfig', 'float16_dynamic_qconfig',
'float_qparams_weight_only_qconfig',
# QAT utilities
'default_qat_qconfig', 'prepare_qat', 'quantize_qat',
# module transformations
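With the new `__all__` entry, the qconfig is importable from the package root as well as from `torch.quantization.qconfig`; a quick sanity check:

```python
from torch.quantization import float_qparams_weight_only_qconfig
from torch.quantization.qconfig import float_qparams_weight_only_qconfig as from_qconfig

assert float_qparams_weight_only_qconfig is from_qconfig
```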
torch/quantization/fx/quantization_patterns.py (10 changes: 5 additions & 5 deletions)
@@ -447,13 +447,13 @@ def __init__(self, quantizer, node):

def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_dict=None):
# Supported combinations are:
# quant_type | activation (compute_type) | weight
# weight_only | float32 (torch.uint8) | quint8
# weight_only | float32 (torch.uint8) | quint4x2
# quant_type | activation | weight | activation_compute_type
# weight_only | float32 | quint8 | None
# weight_only | float32 | quint4x2 | None
# tuple (activation_dtype, weight_dtype, compute_dtype)
supported_dtypes = [
(torch.float32, torch.quint8, torch.quint8),
(torch.float32, torch.quint4x2, torch.quint8),
(torch.float32, torch.quint8, None),
(torch.float32, torch.quint4x2, None),
]
assert node.op == 'call_module'
emb_node = node
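To illustrate the table above, here is a hypothetical helper (not the PR's code) that derives the `(activation_dtype, weight_dtype, compute_dtype)` tuple from a qconfig and checks it against the supported list; with `float_qparams_weight_only_qconfig` the activation side is a float32 placeholder with no compute dtype, so it matches the weight_only rows:

```python
import torch
from torch.quantization import float_qparams_weight_only_qconfig

def qconfig_dtypes(qconfig):
    # Hypothetical re-implementation of the dtype-tuple derivation, for illustration only.
    activation = qconfig.activation()
    weight = qconfig.weight()
    compute_dtype = getattr(activation, "compute_dtype", None)
    return (activation.dtype, weight.dtype, compute_dtype)

supported_dtypes = [
    (torch.float32, torch.quint8, None),
    (torch.float32, torch.quint4x2, None),
]

dtypes = qconfig_dtypes(float_qparams_weight_only_qconfig)
print(dtypes)                      # expect (torch.float32, torch.quint8, None)
print(dtypes in supported_dtypes)  # expect True
```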
torch/quantization/observer.py (3 changes: 2 additions & 1 deletion)
@@ -989,7 +989,7 @@ class PlaceholderObserver(ObserverBase):
custom_op_name: (temporary) specify this observer for an operator that doesn't require any observation
(Can be used in Graph Mode Passes for special case ops).
"""
def __init__(self, dtype=torch.float16, custom_op_name="", compute_dtype=None):
def __init__(self, dtype=torch.float32, custom_op_name="", compute_dtype=None):
super(PlaceholderObserver, self).__init__(dtype=dtype)
# dtype of input of the target operator, e.g. for dynamic quantization
# ops, the dtype will be float32
@@ -1126,6 +1126,7 @@ def load_observer_state_dict(mod, obs_dict):

# Restrict activations to be in the range (0,127)
default_observer = MinMaxObserver.with_args(reduce_range=True)
default_placeholder_observer = PlaceholderObserver
default_debug_observer = RecordingObserver
default_weight_observer = MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric)
default_histogram_observer = HistogramObserver.with_args(reduce_range=True)
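The new `default_placeholder_observer` alias is simply the `PlaceholderObserver` class, so instantiating it now yields a float32 pass-through observer; a short sketch:

```python
import torch
from torch.quantization.observer import PlaceholderObserver, default_placeholder_observer

obs = default_placeholder_observer()  # equivalent to PlaceholderObserver()
print(obs.dtype)                      # torch.float32 with the new default

x = torch.randn(4)
print(torch.equal(obs(x), x))         # True: the observer does not touch the tensor
# obs.calculate_qparams()             # would raise: placeholder observers have no qparams
```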
torch/quantization/qconfig.py (4 changes: 2 additions & 2 deletions)
@@ -67,8 +67,8 @@ def __new__(cls, activation=torch.nn.Identity, weight=torch.nn.Identity):
per_channel_dynamic_qconfig = QConfigDynamic(activation=default_dynamic_quant_observer,
weight=default_per_channel_weight_observer)

float_qparams_dynamic_qconfig = QConfigDynamic(activation=default_dynamic_quant_observer,
weight=default_float_qparams_observer)
float_qparams_weight_only_qconfig = QConfig(activation=default_placeholder_observer,
weight=default_float_qparams_observer)

default_qat_qconfig = QConfig(activation=default_fake_quant,
weight=default_weight_fake_quant)
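A quick sketch of what the new `float_qparams_weight_only_qconfig` produces when its observer factories are called; only the weight side carries real quantization parameters:

```python
import torch
from torch.quantization.qconfig import float_qparams_weight_only_qconfig

act_obs = float_qparams_weight_only_qconfig.activation()  # pass-through placeholder
wt_obs = float_qparams_weight_only_qconfig.weight()       # per-channel float-qparams observer

print(type(act_obs).__name__, act_obs.dtype)  # PlaceholderObserver torch.float32
print(type(wt_obs).__name__, wt_obs.dtype)    # PerChannelMinMaxObserver torch.quint8
print(wt_obs.qscheme)                         # torch.per_channel_affine_float_qparams
```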
torch/quantization/quantize.py (4 changes: 2 additions & 2 deletions)
@@ -18,7 +18,7 @@
)

from .stubs import DeQuantStub, QuantWrapper
from .qconfig import default_dynamic_qconfig, float16_dynamic_qconfig, float_qparams_dynamic_qconfig
from .qconfig import default_dynamic_qconfig, float16_dynamic_qconfig, float_qparams_weight_only_qconfig

def is_activation_post_process(module):
return (isinstance(module, torch.quantization.ObserverBase) or
@@ -348,7 +348,7 @@ def quantize_dynamic(model, qconfig_spec=None, dtype=torch.qint8,
}
elif dtype == torch.quint8:
qconfig_spec = {
nn.EmbeddingBag : float_qparams_dynamic_qconfig,
nn.EmbeddingBag : float_qparams_weight_only_qconfig,
}
else:
raise ValueError(
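With this mapping, passing `dtype=torch.quint8` and no `qconfig_spec` to `quantize_dynamic` applies the weight-only qconfig to `EmbeddingBag` modules; a minimal sketch (the model is illustrative, and `mode='sum'` is used because that is what the byte embedding_bag kernels implement):

```python
import torch
import torch.nn as nn
from torch.quantization import quantize_dynamic

class BagModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = nn.EmbeddingBag(num_embeddings=10, embedding_dim=12,
                                   mode='sum', include_last_offset=True)

    def forward(self, indices, offsets):
        return self.emb(indices, offsets)

model = BagModel().eval()
# qconfig_spec is left as None: dtype=torch.quint8 selects
# {nn.EmbeddingBag: float_qparams_weight_only_qconfig} internally.
qmodel = quantize_dynamic(model, dtype=torch.quint8)
print(type(qmodel.emb))  # expect torch.nn.quantized.EmbeddingBag
```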
torch/testing/_internal/common_quantization.py (4 changes: 2 additions & 2 deletions)
@@ -11,7 +11,7 @@
from torch.testing._internal.common_utils import TestCase
from torch.quantization import QuantWrapper, QuantStub, DeQuantStub, \
default_qconfig, default_dynamic_qconfig, default_per_channel_qconfig, QConfig, default_observer, default_weight_observer, \
propagate_qconfig_, convert, get_default_qconfig, quantize_dynamic_jit, quantize_jit, float_qparams_dynamic_qconfig, \
propagate_qconfig_, convert, get_default_qconfig, quantize_dynamic_jit, quantize_jit, float_qparams_weight_only_qconfig, \
get_default_qat_qconfig, PerChannelMinMaxObserver, default_dynamic_quant_observer, QConfigDynamic, QuantType
from torch.quantization.quantization_mappings import (
get_default_dynamic_quant_module_mappings,
@@ -1446,7 +1446,7 @@ def __init__(self):
super().__init__()
self.emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=12)
self.fc = torch.nn.Linear(5, 5)
self.emb.qconfig = float_qparams_dynamic_qconfig
self.emb.qconfig = float_qparams_weight_only_qconfig
self.qconfig = default_qconfig

def forward(self, indices, linear_in):
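The updated test model mixes a model-level `default_qconfig` with a module-level weight-only qconfig on the embedding; a small sketch of how `propagate_qconfig_` resolves that (module shapes are illustrative):

```python
import torch.nn as nn
from torch.quantization import (
    default_qconfig,
    float_qparams_weight_only_qconfig,
    propagate_qconfig_,
)

class Mixed(nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = nn.Embedding(num_embeddings=10, embedding_dim=12)
        self.fc = nn.Linear(5, 5)

    def forward(self, indices, x):
        return self.emb(indices), self.fc(x)

m = Mixed()
m.qconfig = default_qconfig                        # model-wide default
m.emb.qconfig = float_qparams_weight_only_qconfig  # override for the embedding only

propagate_qconfig_(m)
print(m.fc.qconfig is default_qconfig)                     # True
print(m.emb.qconfig is float_qparams_weight_only_qconfig)  # True
```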