[quant] Support for 4-bit quantized EmbeddingBag module (#45865)
Summary: Pull Request resolved: #45865

Test Plan:
python test/test_quantization.py TestPostTrainingStatic.test_quantized_embedding_bag
python test/test_quantization.py TestStaticQuantizedModule.test_embedding_bag_api

Imported from OSS

Reviewed By: jerryzh168

Differential Revision: D24120995

fbshipit-source-id: c55fc6b2cfd683d14d2a05be7c04f787fdf8cc79
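
For context, this change adds torch.quint4x2 as a supported weight dtype for torch.nn.quantized.EmbeddingBag, with the forward pass dispatching to the quantized.embedding_bag_4bit operator. Below is a minimal sketch of the eager-mode post-training flow this enables, assuming a PyTorch build with this patch applied; the float module here is a stand-in for the test suite's EmbeddingBagModule, and the input tensors are illustrative only.

import torch
from torch.quantization import (
    prepare, convert, QConfigDynamic,
    PerChannelMinMaxObserver, default_dynamic_quant_observer,
)

class EmbeddingBagModule(torch.nn.Module):
    # Stand-in for the float module used by the tests.
    def __init__(self):
        super().__init__()
        self.emb = torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=12,
                                         include_last_offset=True, mode='sum')

    def forward(self, indices, offsets):
        return self.emb(indices, offsets)

model = EmbeddingBagModule().eval()

# float_qparams qconfig targeting 4-bit (quint4x2) weights, as in the updated test.
float_qparams_observer = PerChannelMinMaxObserver.with_args(
    dtype=torch.quint4x2,
    qscheme=torch.per_channel_affine_float_qparams,
    ch_axis=0)
model.qconfig = QConfigDynamic(activation=default_dynamic_quant_observer,
                               weight=float_qparams_observer)

prepare(model, inplace=True)
quantized_model = convert(model)  # model.emb becomes torch.nn.quantized.EmbeddingBag

indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8, 6])
offsets = torch.tensor([0, 4, 10])
out = quantized_model(indices, offsets)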
supriyar authored and facebook-github-bot committed Oct 7, 2020
1 parent 11c3261 commit 43dc7ef
Showing 5 changed files with 102 additions and 70 deletions.
66 changes: 38 additions & 28 deletions test/quantization/test_quantize.py
@@ -25,6 +25,9 @@
float_qparams_dynamic_qconfig,
register_observed_custom_module_mapping,
register_quantized_custom_module_mapping,
PerChannelMinMaxObserver,
QConfigDynamic,
default_dynamic_quant_observer
)

from torch.testing._internal.common_quantization import (
@@ -538,42 +541,49 @@ def test_quantized_embedding_bag(self):
r""" Test the post-training quantization flow, serialization and scripting
of embedding_bag modules
"""
model = EmbeddingBagModule().eval()
indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8, 3, 2, 3, 6, 3, 6, 5, 7, 0, 8, 4, 6, 5, 8, 2, 3])
offsets = torch.tensor([0, 19, 20, 28, 28, 32])
weights = torch.randn(10, 12, dtype=torch.float32)

model.qconfig = float_qparams_dynamic_qconfig
prepare(model, inplace=True)
quantized_model = convert(model)

per_sample_weights = torch.from_numpy(np.random.uniform(
low=0.01, high=0.5, size=[len(indices)]).astype(np.float32))

# Test to make sure module is quantized correctly.
self.assertTrue('QuantizedEmbeddingBag' in str(quantized_model))
self.checkDynamicQuantizedModule(quantized_model.emb, torch.nn.quantized.EmbeddingBag, torch.quint8)
self.checkScriptable(quantized_model, [[indices, offsets, per_sample_weights]], check_save_load=True)
for dtype in [torch.quint8, torch.quint4x2]:
model = EmbeddingBagModule().eval()
float_qparams_observer = PerChannelMinMaxObserver.with_args(dtype=dtype,
qscheme=torch.per_channel_affine_float_qparams,
ch_axis=0)
float_qparams_qconfig = QConfigDynamic(activation=default_dynamic_quant_observer,
weight=float_qparams_observer)
model.qconfig = float_qparams_qconfig

class EmbeddingBagWithLinear(torch.nn.Module):
def __init__(self):
super().__init__()
self.emb = torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=12,
include_last_offset=True, scale_grad_by_freq=False, mode='sum')
self.fc = torch.nn.Linear(5, 5)
prepare(model, inplace=True)
quantized_model = convert(model)

def forward(self, indices, offsets, per_sample_weights, linear_in):
return self.emb(indices, offsets, per_sample_weights), self.fc(linear_in)
per_sample_weights = torch.from_numpy(np.random.uniform(
low=0.01, high=0.5, size=[len(indices)]).astype(np.float32))

# Test quantization of embedding_bag layer only
model = EmbeddingBagWithLinear().eval()
model.emb.qconfig = float_qparams_dynamic_qconfig
prepare(model, inplace=True)
quantized_model = convert(model)
# Test to make sure module is quantized correctly.
self.assertTrue('QuantizedEmbeddingBag' in str(quantized_model))
self.checkDynamicQuantizedModule(quantized_model.emb, torch.nn.quantized.EmbeddingBag, torch.quint8)
self.checkScriptable(quantized_model, [[indices, offsets, per_sample_weights]], check_save_load=True)

self.assertTrue('QuantizedEmbeddingBag' in str(quantized_model))
self.checkLinear(model.fc)
self.checkDynamicQuantizedModule(quantized_model.emb, torch.nn.quantized.EmbeddingBag, torch.quint8)
class EmbeddingBagWithLinear(torch.nn.Module):
def __init__(self):
super().__init__()
self.emb = torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=12,
include_last_offset=True, scale_grad_by_freq=False, mode='sum')
self.fc = torch.nn.Linear(5, 5)

def forward(self, indices, offsets, per_sample_weights, linear_in):
return self.emb(indices, offsets, per_sample_weights), self.fc(linear_in)

# Test quantization of embedding_bag layer only
model2 = EmbeddingBagWithLinear().eval()
model2.emb.qconfig = float_qparams_qconfig
prepare(model2, inplace=True)
quantized_model = convert(model2)

self.assertTrue('QuantizedEmbeddingBag' in str(quantized_model))
self.checkLinear(model2.fc)
self.checkDynamicQuantizedModule(quantized_model.emb, torch.nn.quantized.EmbeddingBag, torch.quint8)

@skipIfNoFBGEMM
def test_custom_module_class(self):
56 changes: 33 additions & 23 deletions test/quantization/test_quantized_module.py
@@ -7,7 +7,8 @@
import torch.quantization

from torch.quantization import (
default_float_qparams_observer
default_float_qparams_observer,
PerChannelMinMaxObserver
)
from torch.testing._internal.common_quantization import (
QuantizationTestCase,
@@ -742,7 +743,7 @@ def test_embedding_api(self, num_embeddings, embedding_dim, set_qconfig):
w_packed = qemb._packed_params._packed_weight
module_out = qemb(indices)

# Call the qembedding_bag operator directly
# Call the qembedding operator directly
ref = torch.ops.quantized.embedding_byte(w_packed, indices, sparse=False)
self.assertEqual(module_out, ref)
self.checkEmbeddingSerialization(qemb, num_embeddings, embedding_dim, indices, None, set_qconfig=False, is_emb_bag=False)
@@ -758,6 +759,7 @@ def test_embedding_api(self, num_embeddings, embedding_dim, set_qconfig):
def test_embedding_bag_api(self, num_embeddings, embedding_dim, num_offsets, set_qconfig):
r"""Test execution and serialization for dynamic quantized embedding_bag modules on int8
"""

num_lengths = np.random.randint(1, 6)
lengths = np.random.randint(0, 21, size=num_lengths).astype(np.int32)
num_indices = np.sum(lengths)
@@ -768,28 +770,36 @@ def test_embedding_bag_api(self, num_embeddings, embedding_dim, num_offsets, set
offsets = torch.cat((offsets, torch.tensor([indices.size(0)], dtype=torch.long)), 0)
weights = torch.from_numpy((np.random.random_sample((num_embeddings, embedding_dim)) + 1).astype(np.float32))

obs = default_float_qparams_observer()
obs(weights)
# Get the scale and zero point for the weight tensor
qparams = obs.calculate_qparams()
# Quantize the weights to 8bits
qweight = torch.quantize_per_channel(weights, qparams[0], qparams[1], axis=0, dtype=torch.quint8)
qemb = nnq.EmbeddingBag(num_embeddings=num_embeddings, embedding_dim=embedding_dim,
include_last_offset=True, mode='sum', _weight=qweight)
qemb(indices, offsets)

# Ensure the module has the correct weights
self.assertEqual(qweight, qemb.weight())

w_packed = qemb._packed_params._packed_weight
module_out = qemb(indices, offsets)
for qdtype in [torch.quint8, torch.quint4x2]:
obs = PerChannelMinMaxObserver(dtype=qdtype, qscheme=torch.per_channel_affine_float_qparams, ch_axis=0)
obs(weights)
# Get the scale and zero point for the weight tensor
qparams = obs.calculate_qparams()
# Quantize the weights to the target dtype (8-bit or 4-bit)
qweight = torch.quantize_per_channel(weights, qparams[0], qparams[1], axis=0, dtype=qdtype)
qemb = nnq.EmbeddingBag(num_embeddings=num_embeddings, embedding_dim=embedding_dim,
include_last_offset=True, mode='sum', _weight=qweight, dtype=qdtype)
qemb(indices, offsets)

# Ensure the module has the correct weights
self.assertEqual(qweight, qemb.weight())

w_packed = qemb._packed_params._packed_weight
module_out = qemb(indices, offsets)

# Call the qembedding_bag operator directly
if qdtype == torch.quint8:
ref = torch.ops.quantized.embedding_bag_byte(w_packed, indices, offsets, mode=0,
per_sample_weights=None,
include_last_offset=True)
else:
ref = torch.ops.quantized.embedding_bag_4bit(w_packed, indices, offsets, mode=0,
per_sample_weights=None,
include_last_offset=True)

# Call the qembedding_bag operator directly
ref = torch.ops.quantized.embedding_bag_byte(w_packed, indices, offsets, mode=0,
per_sample_weights=None,
include_last_offset=True)
self.assertEqual(module_out, ref)
self.checkEmbeddingSerialization(qemb, num_embeddings, embedding_dim, indices, offsets, set_qconfig, is_emb_bag=True)
self.assertEqual(module_out, ref)
self.checkEmbeddingSerialization(qemb, num_embeddings, embedding_dim, indices,
offsets, set_qconfig, is_emb_bag=True, dtype=qdtype)

class TestDynamicQuantizedModule(QuantizationTestCase):
@given(
33 changes: 20 additions & 13 deletions torch/nn/quantized/modules/embedding_ops.py
@@ -11,31 +11,31 @@ class EmbeddingPackedParams(torch.nn.Module):
def __init__(self, num_embeddings, embedding_dim, dtype=torch.quint8):
super(EmbeddingPackedParams, self).__init__()
self.dtype = dtype
if self.dtype == torch.quint8:
if self.dtype in [torch.quint8, torch.quint4x2]:
scales = torch.ones(num_embeddings, dtype=torch.float)
zero_points = torch.zeros(num_embeddings, dtype=torch.float)
wq = torch._empty_per_channel_affine_quantized([num_embeddings, embedding_dim], scales=scales,
zero_points=zero_points,
axis=0, dtype=torch.quint8)
axis=0, dtype=self.dtype)
self.set_weight(wq)
else:
raise RuntimeError('Unsupported dtype on quantized embedding!')
raise NotImplementedError('Unsupported dtype on quantized embedding! Supports quint8 and quint4x2.')

@torch.jit.export
def set_weight(self, weight):
# type: (torch.Tensor) -> None
if self.dtype == torch.quint8:
if self.dtype in [torch.quint8, torch.quint4x2]:
self._packed_weight = torch.ops.quantized.embedding_bag_prepack(weight)
else:
raise RuntimeError('Unsupported dtype on quantized embedding!')
raise NotImplementedError('Unsupported dtype for quantized embedding prepack! Supports quint8 and quint4x2.')


@torch.jit.export
def _weight(self):
if self.dtype == torch.quint8:
if self.dtype in [torch.quint8, torch.quint4x2]:
return torch.ops.quantized.embedding_bag_unpack(self._packed_weight)
else:
raise RuntimeError('Unsupported dtype on quantized embedding!')
raise NotImplementedError('Unsupported dtype for quantized embedding unpack! Supports quint8 and quint4x2.')

def forward(self, x):
return x
@@ -192,17 +192,23 @@ def __init__(self, num_embeddings: int, embedding_dim: int,
max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False,
mode: str = 'sum', sparse: bool = False, _weight: Optional[Tensor] = None,
include_last_offset: bool = False, dtype=torch.quint8) -> None:
super(EmbeddingBag, self).__init__(num_embeddings, embedding_dim, _weight=_weight)
super(EmbeddingBag, self).__init__(num_embeddings, embedding_dim, _weight=_weight, dtype=dtype)

self.mode = mode
self.sparse = sparse
self.include_last_offset = include_last_offset
self.dtype = dtype

def forward(self, indices: Tensor, offsets: Optional[Tensor] = None, per_sample_weights: Optional[Tensor] = None,
compressed_indices_mapping: Optional[Tensor] = None) -> Tensor:
return torch.ops.quantized.embedding_bag_byte(self._packed_params._packed_weight, indices, offsets, False, 0,
self.sparse, per_sample_weights, compressed_indices_mapping,
self.include_last_offset)
if self.dtype == torch.quint4x2:
return torch.ops.quantized.embedding_bag_4bit(self._packed_params._packed_weight, indices, offsets, False, 0,
self.sparse, per_sample_weights, compressed_indices_mapping,
self.include_last_offset)
else:
return torch.ops.quantized.embedding_bag_byte(self._packed_params._packed_weight, indices, offsets, False, 0,
self.sparse, per_sample_weights, compressed_indices_mapping,
self.include_last_offset)

def _get_name(self):
return 'QuantizedEmbeddingBag'
@@ -226,13 +232,14 @@ def from_float(cls, mod):

dtype = weight_observer.dtype

assert dtype == torch.quint8, 'The only supported dtype for nnq.EmbeddingBag is torch.quint8'
assert dtype == torch.quint8 or dtype == torch.quint4x2, \
    'The only supported dtypes for nnq.EmbeddingBag are torch.quint8 and torch.quint4x2'

# Run the observer to calculate qparams.
weight_observer(mod.weight)
qweight = _quantize_weight(mod.weight.float(), weight_observer)

# Create quantized EmbeddingBag module and pass in the quantized weight
qembedding_bag = EmbeddingBag(mod.num_embeddings, mod.embedding_dim)
qembedding_bag = EmbeddingBag(mod.num_embeddings, mod.embedding_dim, dtype=dtype)
qembedding_bag.set_weight(qweight)
return qembedding_bag
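
As a rough usage sketch (not part of the diff), the new dtype argument also lets the quantized module be built directly from a pre-quantized quint4x2 weight, mirroring the updated test_embedding_bag_api test; shapes and values here are illustrative assumptions.

import torch
import torch.nn.quantized as nnq
from torch.quantization import PerChannelMinMaxObserver

num_embeddings, embedding_dim = 10, 16
weights = torch.randn(num_embeddings, embedding_dim, dtype=torch.float32)

# Observe the float weight and quantize it per channel with float qparams.
obs = PerChannelMinMaxObserver(dtype=torch.quint4x2,
                               qscheme=torch.per_channel_affine_float_qparams,
                               ch_axis=0)
obs(weights)
scales, zero_points = obs.calculate_qparams()
qweight = torch.quantize_per_channel(weights, scales, zero_points,
                                     axis=0, dtype=torch.quint4x2)

# Construct the quantized EmbeddingBag directly from the 4-bit weight.
qemb = nnq.EmbeddingBag(num_embeddings=num_embeddings, embedding_dim=embedding_dim,
                        include_last_offset=True, mode='sum',
                        _weight=qweight, dtype=torch.quint4x2)

indices = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=torch.long)
offsets = torch.tensor([0, 5, 10], dtype=torch.long)
out = qemb(indices, offsets)  # dispatches to torch.ops.quantized.embedding_bag_4bit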
2 changes: 1 addition & 1 deletion torch/nn/quantized/modules/utils.py
@@ -17,7 +17,7 @@ def _quantize_weight(float_wt, observer):
elif observer.qscheme in [torch.per_channel_affine_float_qparams]:
qweight = torch.quantize_per_channel(
float_wt,
wt_scale.to(torch.float), wt_zp.to(torch.float), observer.ch_axis, torch.quint8)
wt_scale.to(torch.float), wt_zp.to(torch.float), observer.ch_axis, observer.dtype)
else:
raise ValueError("Unexpected qscheme " + observer.qscheme)
return qweight
15 changes: 10 additions & 5 deletions torch/testing/_internal/common_quantization.py
@@ -12,7 +12,7 @@
from torch.quantization import QuantWrapper, QuantStub, DeQuantStub, \
default_qconfig, default_dynamic_qconfig, default_per_channel_qconfig, QConfig, default_observer, default_weight_observer, \
propagate_qconfig_, convert, get_default_qconfig, quantize_dynamic_jit, quantize_jit, float_qparams_dynamic_qconfig, \
get_default_qat_qconfig
get_default_qat_qconfig, PerChannelMinMaxObserver, default_dynamic_quant_observer, QConfigDynamic
from torch.quantization import (
is_custom_module_class,
is_observed_custom_module,
@@ -667,7 +667,8 @@ def checkGraphModeFxOp(self, model, inputs, quant_type,
qgraph_to_check, expected_node, expected_node_occurrence, expected_node_list)


def checkEmbeddingSerialization(self, qemb, num_embeddings, embedding_dim, indices, offsets, set_qconfig, is_emb_bag):
def checkEmbeddingSerialization(self, qemb, num_embeddings, embedding_dim, indices, offsets,
set_qconfig, is_emb_bag, dtype=torch.quint8):
# Test serialization of dynamic EmbeddingBag module using state_dict
if is_emb_bag:
inputs = [indices, offsets]
@@ -690,9 +691,9 @@ def checkEmbeddingSerialization(self, qemb, num_embeddings, embedding_dim, indic
# Check state dict serialization and torch.save APIs
if is_emb_bag:
loaded_qemb = nnq.EmbeddingBag(num_embeddings=num_embeddings, embedding_dim=embedding_dim,
include_last_offset=True, mode='sum')
include_last_offset=True, mode='sum', dtype=dtype)
else:
loaded_qemb = nnq.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
loaded_qemb = nnq.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim, dtype=dtype)
self.check_eager_serialization(qemb, loaded_qemb, inputs)

loaded_qemb.load_state_dict(loaded_dict)
@@ -711,7 +712,11 @@ def checkEmbeddingSerialization(self, qemb, num_embeddings, embedding_dim, indic
float_embedding = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim)

if set_qconfig:
float_embedding.qconfig = float_qparams_dynamic_qconfig
float_qparams_observer = PerChannelMinMaxObserver.with_args(dtype=dtype,
qscheme=torch.per_channel_affine_float_qparams,
ch_axis=0)
float_embedding.qconfig = QConfigDynamic(activation=default_dynamic_quant_observer,
weight=float_qparams_observer)

prepare_dynamic(float_embedding)
