🐛 Describe the bug
It seems that Quantization Aware Training in recent PyTorch (1.12.1 with CUDA 10.2 here) does not support nn.Embedding layers. The following minimal example:
```python
import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(100, 100)

    def forward(self, x):
        return self.embed(x)

model = Model()
model.train()
model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
model = torch.quantization.prepare_qat(model)
torch.quantization.convert(model)
```
fails with the following error:
```
torch/nn/quantized/modules/embedding_ops.py:162, in Embedding.from_float(cls, mod)
    160 dtype = weight_observer.dtype
    161 is_float_qparams_qconfig = weight_observer.qscheme == torch.per_channel_affine_float_qparams
--> 162 assert is_float_qparams_qconfig, \
    163     'Embedding quantization is only supported with float_qparams_weight_only_qconfig.'
    165 assert dtype == torch.quint8 or dtype == torch.quint4x2, \
    166     f'The only supported dtype for nnq.Embedding is torch.quint8 and torch.quint4x2, got {dtype}'
    168 # Run the observer to calculate qparams.

AssertionError: Embedding quantization is only supported with float_qparams_weight_only_qconfig.
```
Issues #41396 and #65185 offer workarounds for static and dynamic quantization, but not for quantization aware training. Is there any way to run QAT on a model containing nn.Embedding layers? If not, fixing this issue would be valuable for applications such as transformers, among others.
Thanks for the help!
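For reference, here is the workaround sketch I have been experimenting with. It assumes that `default_embedding_qat_qconfig` is importable from `torch.ao.quantization.qconfig` (it appears to exist in recent releases, but I have not confirmed in which version it was introduced); its fake-quant uses the `per_channel_affine_float_qparams` scheme that the `convert` assertion above expects:

```python
import torch
import torch.nn as nn
# Assumption: this qconfig is available in recent PyTorch releases.
from torch.ao.quantization.qconfig import default_embedding_qat_qconfig

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(100, 100)

    def forward(self, x):
        return self.embed(x)

model = Model()
model.train()
# Use the embedding-specific QAT qconfig instead of the default
# 'fbgemm' qconfig, so the weight fake-quant uses float qparams.
model.qconfig = default_embedding_qat_qconfig
model = torch.quantization.prepare_qat(model)
model = torch.quantization.convert(model.eval())
# The quantized embedding still takes integer indices and returns
# a dequantized float tensor.
out = model(torch.randint(0, 100, (4,)))
print(out.shape)
```

In a larger model, the same qconfig could presumably be assigned only to the embedding submodule (e.g. `model.embed.qconfig = default_embedding_qat_qconfig`) while the rest of the model keeps `get_default_qat_qconfig('fbgemm')`, but I have not verified that end to end.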
Versions
```
PyTorch version: 1.12.1+cu102
Is debug build: False
CUDA used to build PyTorch: 10.2
ROCM used to build PyTorch: N/A

OS: CentOS Linux 7 (Core) (x86_64)
GCC version: (GCC) 8.3.1 20190311 (Red Hat 8.3.1-3)
Clang version: Could not collect
CMake version: version 2.8.12.2
Libc version: glibc-2.17

Python version: 3.9.12 (main, Sep 14 2022, 07:55:59) [GCC 4.8.5 20150623 (Red Hat 4.8.5-44)] (64-bit runtime)
Python platform: Linux-5.4.86-1.el7.x86_64-x86_64-with-glibc2.17
Is CUDA available: True
CUDA runtime version: 10.2.89
CUDA_MODULE_LOADING set to:
GPU models and configuration:
GPU 0: Tesla V100-PCIE-16GB
GPU 1: Tesla V100-PCIE-16GB
Nvidia driver version: 440.64.00
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] gpytorch==1.9.0
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.23.3
[pip3] torch==1.12.1
[pip3] torch-tb-profiler==0.4.0
[conda] Could not collect
```