[quant] update embedding module to not store qweight
Summary:
Previously we were storing the quantized weight as a module attribute, which
resulted in the weight getting serialized as part of the model a second time.
We don't need this since we already store the unpacked weights as part of the model
through the packed params.
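
For intuition, here is a minimal sketch of the failure mode (not from the PR; `Holder`, `Duplicated`, and `Deduped` are made-up stand-ins for `EmbeddingPackedParams` and the embedding module): a scripted module that keeps a tensor both as a plain attribute and inside a submodule writes two copies into the saved archive.
```
import os
import torch

class Holder(torch.nn.Module):
    # Stand-in for EmbeddingPackedParams: owns the one copy we actually need.
    def __init__(self, w: torch.Tensor):
        super().__init__()
        self.weight = w

    def forward(self) -> torch.Tensor:
        return self.weight

class Duplicated(torch.nn.Module):
    # Old behavior: tensor kept as an attribute AND handed to the holder.
    def __init__(self):
        super().__init__()
        w = torch.randn(10000, 12)
        self.qweight = w                  # extra copy TorchScript serializes
        self._packed = Holder(w.clone())  # packing produces a distinct tensor

    def forward(self) -> torch.Tensor:
        return self._packed()

class Deduped(torch.nn.Module):
    # New behavior: the tensor stays a local; only the holder stores it.
    def __init__(self):
        super().__init__()
        w = torch.randn(10000, 12)
        self._packed = Holder(w.clone())

    def forward(self) -> torch.Tensor:
        return self._packed()

for cls in (Duplicated, Deduped):
    torch.jit.script(cls()).save('tmp.pt')
    print(cls.__name__, os.path.getsize('tmp.pt'))  # Duplicated is roughly 2x
```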

Test Plan:
Before
```
Archive:  tmp.pt
 Length   Method    Size  Cmpr    Date    Time   CRC-32   Name
--------  ------  ------- ---- ---------- ----- --------  ----
     586  Stored      586   0% 00-00-1980 00:00 5fefdda0  tmp/extra/producer_info.json
 1588700  Stored  1588700   0% 00-00-1980 00:00 04e0da4c  tmp/data/0
   63548  Stored    63548   0% 00-00-1980 00:00 0ceb1f45  tmp/data/1
   63548  Stored    63548   0% 00-00-1980 00:00 517bc3ab  tmp/data/2
 1588700  Stored  1588700   0% 00-00-1980 00:00 dbe88c73  tmp/data/3
   63548  Stored    63548   0% 00-00-1980 00:00 d8dc47c4  tmp/data/4
   63548  Stored    63548   0% 00-00-1980 00:00 b9e0c20f  tmp/data/5
    1071  Stored     1071   0% 00-00-1980 00:00 10dc9350  tmp/data.pkl
     327  Defl:N      203  38% 00-00-1980 00:00 dfddb661  tmp/code/__torch__/___torch_mangle_0.py
     185  Stored      185   0% 00-00-1980 00:00 308f580b  tmp/code/__torch__/___torch_mangle_0.py.debug_pkl
    1730  Defl:N      515  70% 00-00-1980 00:00 aa11f799  tmp/code/__torch__/torch/nn/quantized/modules/embedding_ops.py
    1468  Defl:N      636  57% 00-00-1980 00:00 779609a6  tmp/code/__torch__/torch/nn/quantized/modules/embedding_ops.py.debug_pkl
       0  Stored        0   0% 00-00-1980 00:00 00000000  tmp/code/__torch__/torch/classes/quantized.py
       6  Stored        6   0% 00-00-1980 00:00 816d0907  tmp/code/__torch__/torch/classes/quantized.py.debug_pkl
       4  Stored        4   0% 00-00-1980 00:00 57092f6  tmp/constants.pkl
       2  Stored        2   0% 00-00-1980 00:00 55679ed1  tmp/version
--------          -------  ---                            -------
 3436971          3434800   0%                            16 files
```
After
```
Archive:  tmp.pt
 Length   Method    Size  Cmpr    Date    Time   CRC-32   Name
--------  ------  ------- ---- ---------- ----- --------  ----
 1588700  Stored  1588700   0% 00-00-1980 00:00 a4da6981  tmp/data/0
   63548  Stored    63548   0% 00-00-1980 00:00 74d9b607  tmp/data/1
   63548  Stored    63548   0% 00-00-1980 00:00 e346a0c2  tmp/data/2
     952  Stored      952   0% 00-00-1980 00:00 eff8706e  tmp/data.pkl
     375  Defl:N      227  40% 00-00-1980 00:00 96c77b68  tmp/code/__torch__/quantization/test_quantize/___torch_mangle_23.py
     228  Defl:N      162  29% 00-00-1980 00:00 6a378113  tmp/code/__torch__/quantization/test_quantize/___torch_mangle_23.py.debug_pkl
    1711  Defl:N      509  70% 00-00-1980 00:00 66d8fd61  tmp/code/__torch__/torch/nn/quantized/modules/embedding_ops.py
    1473  Defl:N      634  57% 00-00-1980 00:00 beb2323b  tmp/code/__torch__/torch/nn/quantized/modules/embedding_ops.py.debug_pkl
       0  Stored        0   0% 00-00-1980 00:00 00000000  tmp/code/__torch__/torch/classes/quantized.py
       6  Stored        6   0% 00-00-1980 00:00 816d0907  tmp/code/__torch__/torch/classes/quantized.py.debug_pkl
       4  Stored        4   0% 00-00-1980 00:00 57092f6  tmp/constants.pkl
       2  Stored        2   0% 00-00-1980 00:00 55679ed1  tmp/version
--------          -------  ---                            -------
 1720547          1718292   0%                            12 files
```
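
The listings above are the archive contents of the saved TorchScript model. A rough way to reproduce this kind of dump (a sketch, assuming a scriptable quantized embedding module; the exact model and script used for this test plan aren't shown):
```
import zipfile
import torch

# Script a quantized embedding module, save it, and list the archive entries.
emb = torch.nn.quantized.Embedding(num_embeddings=10, embedding_dim=12)
torch.jit.script(emb).save('tmp.pt')

with zipfile.ZipFile('tmp.pt') as zf:
    for info in zf.infolist():
        print(f'{info.file_size:>9}  {info.filename}')
```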

ghstack-source-id: cf0e58cb0679e62e28dac92009f1c564e126e300
Pull Request resolved: #50418
supriyar committed Jan 14, 2021
1 parent 9ebea77 commit 63a24e0
Showing 2 changed files with 30 additions and 6 deletions.
24 changes: 24 additions & 0 deletions test/quantization/test_quantize.py
@@ -89,6 +89,7 @@
 import numpy as np

 class TestPostTrainingStatic(QuantizationTestCase):
+
     def test_single_layer(self):
         r"""Quantize SingleLayerLinearModel which has one Linear module, make sure it is swapped
         to nnq.Linear which is the quantized version of the module
@@ -537,6 +538,29 @@ def test_quantized_embedding(self):
         self.assertTrue('QuantizedLinear' in str(model))
         self.checkQuantizedLinear(model.fc)

+    @skipIfNoFBGEMM
+    def test_embedding_linear_dynamic(self):
+        class EmbeddingWithLinearDynamic(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=12)
+                self.fc = torch.nn.Linear(5, 5)
+
+            def forward(self, indices, linear_in):
+                return self.emb(indices), self.fc(linear_in)
+
+        model = EmbeddingWithLinearDynamic()
+        qconfig_dict = {'fc': default_dynamic_qconfig}
+        quantize_dynamic(model, qconfig_dict, inplace=True)
+
+        model.emb.qconfig = float_qparams_weight_only_qconfig
+        prepare(model, inplace=True)
+        convert(model, inplace=True)
+        self.assertTrue('QuantizedEmbedding' in str(model))
+        self.assertTrue('DynamicQuantizedLinear' in str(model))
+
+
     @skipIfNoFBGEMM
     def test_dequant_stub(self):
         m = QuantStubModel().eval()
12 changes: 6 additions & 6 deletions torch/nn/quantized/modules/embedding_ops.py
@@ -97,16 +97,16 @@ def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optiona
         if _weight is None:
             scales = torch.ones(num_embeddings, dtype=torch.float)
             zero_points = torch.zeros(num_embeddings, dtype=torch.float)
-            self.qweight = torch._empty_per_channel_affine_quantized([num_embeddings, embedding_dim],
-                                                                     scales=scales, zero_points=zero_points,
-                                                                     axis=0, dtype=torch.quint8)
+            qweight = torch._empty_per_channel_affine_quantized([num_embeddings, embedding_dim],
+                                                                scales=scales, zero_points=zero_points,
+                                                                axis=0, dtype=torch.quint8)
         else:
             assert list(_weight.shape) == [num_embeddings, embedding_dim], \
                 'Shape of weight does not match num_embeddings and embedding_dim'
-            self.qweight = _weight
+            qweight = _weight

         self._packed_params = EmbeddingPackedParams(num_embeddings, embedding_dim, dtype)
-        self._packed_params.set_weight(self.qweight)
+        self._packed_params.set_weight(qweight)

     def forward(self, indices: Tensor) -> Tensor:
         return torch.ops.quantized.embedding_byte(self._packed_params._packed_weight, indices)
@@ -119,7 +119,7 @@ def __repr__(self):

     def extra_repr(self):
         extra_repr_str = 'num_embeddings={}, embedding_dim={}, dtype={}, qscheme={}'.format(
-            self.num_embeddings, self.embedding_dim, self._packed_params.dtype, self.qweight.qscheme()
+            self.num_embeddings, self.embedding_dim, self._packed_params.dtype, self.weight().qscheme()
         )

         return extra_repr_str
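
A quick way to observe the change (a sketch, not part of the PR): the converted module no longer carries a `qweight` attribute, while the weight remains reachable through the packed params, which is exactly what the updated `extra_repr` relies on.
```
import torch

m = torch.nn.quantized.Embedding(num_embeddings=10, embedding_dim=12)
assert not hasattr(m, 'qweight')  # the attribute this commit removes
print(m.weight().qscheme())       # weight still available via _packed_params
print(m)                          # extra_repr now calls self.weight().qscheme()
```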
