[quant][fix] Fix quant type classification for float_qparam qconfig (#48069)

Summary:
Pull Request resolved: #48069

Also renamed float_qparams_dynamic_qconfig to float_qparams_weight_only_qconfig.
It's not used in user code yet, so we only need to update the tests.
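
For reference, a minimal eager-mode sketch of the renamed qconfig in use, mirroring the updated test_quantized_embedding test below (the wrapper module and tensor sizes are illustrative, not part of this PR):

    import torch
    import torch.nn as nn
    from torch.quantization import float_qparams_weight_only_qconfig, prepare, convert

    class EmbeddingModule(nn.Module):
        def __init__(self):
            super().__init__()
            self.emb = nn.Embedding(num_embeddings=10, embedding_dim=12)

        def forward(self, indices):
            return self.emb(indices)

    model = EmbeddingModule().eval()
    # Weight-only quantization: activations stay float32, only the embedding
    # weight is observed and quantized with float qparams.
    model.qconfig = float_qparams_weight_only_qconfig
    prepare(model, inplace=True)
    convert(model, inplace=True)
    print(model)  # expect a QuantizedEmbedding module in place of nn.Embedding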

Test Plan: Imported from OSS

Reviewed By: supriyar

Differential Revision: D25010175

fbshipit-source-id: caa3eaa5358a8bc5c808bf5f64e6ebff3e0b61e8
jerryzh168 authored and facebook-github-bot committed Nov 19, 2020
1 parent f0f8b97 commit 576fa09
Showing 9 changed files with 29 additions and 24 deletions.
4 changes: 2 additions & 2 deletions test/quantization/test_quantize.py
@@ -22,7 +22,7 @@
default_dynamic_qconfig,
per_channel_dynamic_qconfig,
float16_dynamic_qconfig,
-float_qparams_dynamic_qconfig,
+float_qparams_weight_only_qconfig,
PerChannelMinMaxObserver,
QConfigDynamic,
default_dynamic_quant_observer,
@@ -521,7 +521,7 @@ def test_quantized_embedding(self):
model = EmbeddingModule().eval()
indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8, 3, 2, 3, 6, 3, 6, 5, 7, 0, 8, 4, 6, 5, 8, 2, 3])
weights = torch.randn(10, 12, dtype=torch.float32)
-model.qconfig = float_qparams_dynamic_qconfig
+model.qconfig = float_qparams_weight_only_qconfig
prepare(model, inplace=True)
convert(model, inplace=True)
self.assertTrue('QuantizedEmbedding' in str(model))
8 changes: 4 additions & 4 deletions test/quantization/test_quantize_fx.py
@@ -21,16 +21,16 @@
quant_type_to_str,
default_qconfig,
default_dynamic_qconfig,
-default_dynamic_quant_observer,
default_qat_qconfig,
float16_dynamic_qconfig,
-float_qparams_dynamic_qconfig,
+float_qparams_weight_only_qconfig,
get_default_qconfig,
get_default_qat_qconfig,
fuse_modules,
prepare,
prepare_qat,
convert,
+default_placeholder_observer,
PerChannelMinMaxObserver,
QConfigDynamic,
FixedQParamsFakeQuantize,
@@ -1947,7 +1947,7 @@ def forward(self, indices):
indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8, 3, 2, 3, 6, 3, 6, 5, 7, 0, 8, 4, 6, 5, 8, 2, 3])
quantized_node = ns.call_module(nnq.Embedding)
configs = [
-(float_qparams_dynamic_qconfig, ns.call_module(nnq.Embedding)),
+(float_qparams_weight_only_qconfig, ns.call_module(nnq.Embedding)),
(None, ns.call_module(nn.Embedding)),
(default_qconfig, ns.call_module(nn.Embedding)),
]
@@ -1982,7 +1982,7 @@ def forward(self, indices, offsets):
float_qparams_observer = PerChannelMinMaxObserver.with_args(dtype=dtype,
qscheme=torch.per_channel_affine_float_qparams,
ch_axis=0)
-float_qparams_qconfig = QConfigDynamic(activation=default_dynamic_quant_observer,
+float_qparams_qconfig = QConfigDynamic(activation=default_placeholder_observer,
weight=float_qparams_observer)
self.checkGraphModeFxOp(
model,
8 changes: 4 additions & 4 deletions torch/nn/quantized/modules/embedding_ops.py
@@ -144,11 +144,11 @@ def from_float(cls, mod):
assert type(mod) == nn.Embedding, 'nnq.' + cls.__name__ + '.from_float only works for ' + \
nn.Embedding.__name__
assert hasattr(mod, 'qconfig'), 'Embedding input float module must have qconfig defined'
-from torch.quantization.qconfig import float_qparams_dynamic_qconfig
+from torch.quantization import float_qparams_weight_only_qconfig
if mod.qconfig is not None and mod.qconfig.weight is not None:
weight_observer = mod.qconfig.weight()
else:
-weight_observer = float_qparams_dynamic_qconfig.weight()
+weight_observer = float_qparams_weight_only_qconfig.weight()

dtype = weight_observer.dtype

@@ -224,11 +224,11 @@ def from_float(cls, mod):
assert type(mod) == nn.EmbeddingBag, 'nnq.' + cls.__name__ + '.from_float only works for ' + \
nn.EmbeddingBag.__name__
assert hasattr(mod, 'qconfig'), 'EmbeddingBag input float module must have qconfig defined'
-from torch.quantization.qconfig import float_qparams_dynamic_qconfig
+from torch.quantization.qconfig import float_qparams_weight_only_qconfig
if mod.qconfig is not None and mod.qconfig.weight is not None:
weight_observer = mod.qconfig.weight()
else:
-weight_observer = float_qparams_dynamic_qconfig.weight()
+weight_observer = float_qparams_weight_only_qconfig.weight()

dtype = weight_observer.dtype

3 changes: 2 additions & 1 deletion torch/quantization/__init__.py
@@ -43,7 +43,7 @@ def default_eval_fn(model, calib_data):
'register_activation_post_process_hook',
# Observers
'ObserverBase', 'WeightObserver', 'observer', 'default_observer',
-'default_weight_observer',
+'default_weight_observer', 'default_placeholder_observer',
# FakeQuantize (for qat)
'default_fake_quant', 'default_weight_fake_quant',
'default_symmetric_fixed_qparams_fake_quant',
@@ -52,6 +52,7 @@ def default_eval_fn(model, calib_data):
'default_histogram_fake_quant',
# QConfig
'QConfig', 'default_qconfig', 'default_dynamic_qconfig', 'float16_dynamic_qconfig',
+'float_qparams_weight_only_qconfig',
# QAT utilities
'default_qat_qconfig', 'prepare_qat', 'quantize_qat',
# module transformations
10 changes: 5 additions & 5 deletions torch/quantization/fx/quantization_patterns.py
@@ -432,13 +432,13 @@ def __init__(self, quantizer, node):

def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_dict=None):
# Supported combinations are:
-# quant_type | activation (compute_type) | weight
-# weight_only | float32 (torch.uint8) | quint8
-# weight_only | float32 (torch.uint8) | quint4x2
+# quant_type | activation | weight | activation_compute_type
+# weight_only | float32 | quint8 | None
+# weight_only | float32 | quint4x2 | None
# tuple (activation_dtype, weight_dtype, compute_dtype)
supported_dtypes = [
-(torch.float32, torch.quint8, torch.quint8),
-(torch.float32, torch.quint4x2, torch.quint8),
+(torch.float32, torch.quint8, None),
+(torch.float32, torch.quint4x2, None),
]
assert node.op == 'call_module'
emb_node = node
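
For orientation, a hypothetical standalone sketch (not code from this PR) of the classification the revised comment table expresses: in the (activation_dtype, weight_dtype, activation_compute_dtype) triple, a compute dtype of None means the activation is never quantized, i.e. weight-only quantization, while a non-None compute dtype signals dynamic quantization:

    import torch

    # Illustrative helper only; the name and return values are assumptions.
    def classify_quant_type(activation_dtype, weight_dtype, compute_dtype):
        if activation_dtype == torch.float32 and compute_dtype is None:
            return "weight_only"
        if activation_dtype == torch.float32 and compute_dtype is not None:
            return "dynamic"
        return "static"

    print(classify_quant_type(torch.float32, torch.quint8, None))         # weight_only
    print(classify_quant_type(torch.float32, torch.quint4x2, None))       # weight_only
    print(classify_quant_type(torch.float32, torch.qint8, torch.quint8))  # dynamic
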
3 changes: 2 additions & 1 deletion torch/quantization/observer.py
@@ -989,7 +989,7 @@ class PlaceholderObserver(ObserverBase):
custom_op_name: (temporary) specify this observer for an operator that doesn't require any observation
(Can be used in Graph Mode Passes for special case ops).
"""
-def __init__(self, dtype=torch.float16, custom_op_name="", compute_dtype=None):
+def __init__(self, dtype=torch.float32, custom_op_name="", compute_dtype=None):
super(PlaceholderObserver, self).__init__(dtype=dtype)
# dtype of input of the target operator, e.g. for dynamic quantization
# ops, the dtype will be float32
@@ -1126,6 +1126,7 @@ def load_observer_state_dict(mod, obs_dict):

# Restrict activations to be in the range (0,127)
default_observer = MinMaxObserver.with_args(reduce_range=True)
+default_placeholder_observer = PlaceholderObserver
default_debug_observer = RecordingObserver
default_weight_observer = MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric)
default_histogram_observer = HistogramObserver.with_args(reduce_range=True)
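
A short sketch of what the newly exported default_placeholder_observer does (behavior summarized from the PlaceholderObserver docstring above; sketch only): it records a dtype, now defaulting to torch.float32, but performs no observation, which is what a weight-only qconfig needs for its activation slot.

    import torch
    from torch.quantization import default_placeholder_observer

    obs = default_placeholder_observer()  # a PlaceholderObserver instance
    x = torch.randn(4)
    y = obs(x)            # pass-through; no statistics are collected
    assert torch.equal(x, y)
    print(obs.dtype)      # torch.float32 after this change
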
7 changes: 5 additions & 2 deletions torch/quantization/qconfig.py
@@ -67,8 +67,11 @@ def __new__(cls, activation=torch.nn.Identity, weight=torch.nn.Identity):
per_channel_dynamic_qconfig = QConfigDynamic(activation=default_dynamic_quant_observer,
weight=default_per_channel_weight_observer)

-float_qparams_dynamic_qconfig = QConfigDynamic(activation=default_dynamic_quant_observer,
-weight=default_float_qparams_observer)
+# TODO: this is weight only quant, change this to QConfigWeightOnly
+# or remove the QConfigDynamic later
+float_qparams_weight_only_qconfig = QConfigDynamic(
+activation=default_placeholder_observer,
+weight=default_float_qparams_observer)

default_qat_qconfig = QConfig(activation=default_fake_quant,
weight=default_weight_fake_quant)
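
For comparison, the updated FX test above assembles an equivalent qconfig by hand from the same observers; a sketch of that construction (the packaged float_qparams_weight_only_qconfig remains the supported spelling):

    import torch
    from torch.quantization import (
        PerChannelMinMaxObserver,
        QConfigDynamic,
        default_placeholder_observer,
    )

    # Per-channel observer with float qparams for the embedding weight.
    float_qparams_observer = PerChannelMinMaxObserver.with_args(
        dtype=torch.quint8,
        qscheme=torch.per_channel_affine_float_qparams,
        ch_axis=0)

    # Activations are left untouched via the placeholder observer;
    # only the weight is observed and quantized.
    my_weight_only_qconfig = QConfigDynamic(
        activation=default_placeholder_observer,
        weight=float_qparams_observer)
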
6 changes: 3 additions & 3 deletions torch/quantization/quantize.py
@@ -17,7 +17,7 @@
)

from .stubs import DeQuantStub, QuantWrapper
-from .qconfig import default_dynamic_qconfig, float16_dynamic_qconfig, float_qparams_dynamic_qconfig
+from .qconfig import default_dynamic_qconfig, float16_dynamic_qconfig, float_qparams_weight_only_qconfig

def is_activation_post_process(module):
return (isinstance(module, torch.quantization.ObserverBase) or
@@ -352,7 +352,7 @@ def quantize_dynamic(model, qconfig_spec=None, dtype=torch.qint8,
}
elif dtype == torch.quint8:
qconfig_spec = {
-nn.EmbeddingBag : float_qparams_dynamic_qconfig,
+nn.EmbeddingBag : float_qparams_weight_only_qconfig,
}
else:
raise ValueError(
@@ -363,7 +363,7 @@
elif dtype is torch.float16:
default_qconfig = float16_dynamic_qconfig
elif dtype is torch.quint8:
-default_qconfig = float_qparams_dynamic_qconfig
+default_qconfig = float_qparams_weight_only_qconfig
else:
raise RuntimeError('Unknown dtype specified for quantize_dynamic: ', str(dtype))
qconfig_spec = dict(zip(qconfig_spec, itertools.repeat(default_qconfig)))
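
With the default spec above, weight-only embedding quantization is also reachable through the top-level dynamic-quantization entry point; a minimal sketch, assuming a model that contains an nn.EmbeddingBag (the wrapper module here is illustrative):

    import torch
    import torch.nn as nn
    from torch.quantization import quantize_dynamic

    class BagModule(nn.Module):
        def __init__(self):
            super().__init__()
            self.emb = nn.EmbeddingBag(num_embeddings=10, embedding_dim=12,
                                       include_last_offset=True, mode='sum')

        def forward(self, indices, offsets):
            return self.emb(indices, offsets)

    model = BagModule().eval()
    # dtype=torch.quint8 with no qconfig_spec selects
    # float_qparams_weight_only_qconfig for nn.EmbeddingBag (see the branch above).
    qmodel = quantize_dynamic(model, dtype=torch.quint8)
    print(qmodel)  # expect a quantized EmbeddingBag
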
4 changes: 2 additions & 2 deletions torch/testing/_internal/common_quantization.py
@@ -12,7 +12,7 @@
from torch.testing._internal.common_utils import TestCase
from torch.quantization import QuantWrapper, QuantStub, DeQuantStub, \
default_qconfig, default_dynamic_qconfig, default_per_channel_qconfig, QConfig, default_observer, default_weight_observer, \
-propagate_qconfig_, convert, get_default_qconfig, quantize_dynamic_jit, quantize_jit, float_qparams_dynamic_qconfig, \
+propagate_qconfig_, convert, get_default_qconfig, quantize_dynamic_jit, quantize_jit, float_qparams_weight_only_qconfig, \
get_default_qat_qconfig, PerChannelMinMaxObserver, default_dynamic_quant_observer, QConfigDynamic, QuantType
from torch.quantization.quantization_mappings import (
get_default_dynamic_quant_module_mappings,
@@ -1449,7 +1449,7 @@ def __init__(self):
super().__init__()
self.emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=12)
self.fc = torch.nn.Linear(5, 5)
-self.emb.qconfig = float_qparams_dynamic_qconfig
+self.emb.qconfig = float_qparams_weight_only_qconfig
self.qconfig = default_qconfig

def forward(self, indices, linear_in):
