diff --git a/test/quantization/test_quantize.py b/test/quantization/test_quantize.py
index 1198466b83a0..71391a9ddee7 100644
--- a/test/quantization/test_quantize.py
+++ b/test/quantization/test_quantize.py
@@ -89,6 +89,7 @@ import numpy as np


 class TestPostTrainingStatic(QuantizationTestCase):
+
     def test_single_layer(self):
         r"""Quantize SingleLayerLinearModel which has one Linear module, make sure it is swapped to
         nnq.Linear which is the quantized version of the module
@@ -537,6 +538,29 @@ def test_quantized_embedding(self):
         self.assertTrue('QuantizedLinear' in str(model))
         self.checkQuantizedLinear(model.fc)

+    @skipIfNoFBGEMM
+    def test_embedding_linear_dynamic(self):
+        r"""Dynamically quantize the Linear and quantize the Embedding weight-only."""
+        class EmbeddingWithLinearDynamic(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=12)
+                self.fc = torch.nn.Linear(5, 5)
+
+            def forward(self, indices, linear_in):
+                return self.emb(indices), self.fc(linear_in)
+
+        model = EmbeddingWithLinearDynamic()
+        qconfig_dict = {'fc' : default_dynamic_qconfig}
+        quantize_dynamic(model, qconfig_dict, inplace=True)
+
+        model.emb.qconfig = float_qparams_weight_only_qconfig
+        prepare(model, inplace=True)
+        convert(model, inplace=True)
+        self.assertTrue('QuantizedEmbedding' in str(model))
+        self.assertTrue('DynamicQuantizedLinear' in str(model))
+
+
     @skipIfNoFBGEMM
     def test_dequant_stub(self):
         m = QuantStubModel().eval()
diff --git a/torch/nn/quantized/modules/embedding_ops.py b/torch/nn/quantized/modules/embedding_ops.py
index 523994b364c8..b12591a6d1de 100644
--- a/torch/nn/quantized/modules/embedding_ops.py
+++ b/torch/nn/quantized/modules/embedding_ops.py
@@ -97,16 +97,16 @@ def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optiona
         if _weight is None:
             scales = torch.ones(num_embeddings, dtype=torch.float)
             zero_points = torch.zeros(num_embeddings, dtype=torch.float)
-            self.qweight = torch._empty_per_channel_affine_quantized([num_embeddings, embedding_dim],
-                                                                     scales=scales, zero_points=zero_points,
-                                                                     axis=0, dtype=torch.quint8)
+            qweight = torch._empty_per_channel_affine_quantized([num_embeddings, embedding_dim],
+                                                                scales=scales, zero_points=zero_points,
+                                                                axis=0, dtype=torch.quint8)
         else:
             assert list(_weight.shape) == [num_embeddings, embedding_dim], \
                 'Shape of weight does not match num_embeddings and embedding_dim'
-            self.qweight = _weight
+            qweight = _weight

         self._packed_params = EmbeddingPackedParams(num_embeddings, embedding_dim, dtype)
-        self._packed_params.set_weight(self.qweight)
+        self._packed_params.set_weight(qweight)

     def forward(self, indices: Tensor) -> Tensor:
         return torch.ops.quantized.embedding_byte(self._packed_params._packed_weight, indices)
@@ -119,7 +119,7 @@ def __repr__(self):

     def extra_repr(self):
         extra_repr_str = 'num_embeddings={}, embedding_dim={}, dtype={}, qscheme={}'.format(
-            self.num_embeddings, self.embedding_dim, self._packed_params.dtype, self.qweight.qscheme()
+            self.num_embeddings, self.embedding_dim, self._packed_params.dtype, self.weight().qscheme()
         )
         return extra_repr_str
