[quant] Quantizable LSTM
ghstack-source-id: afddb9620f0378a22c331ab5a8f75d3d5589bbcc
Pull Request resolved: #49671
z-a-f committed Dec 21, 2020
1 parent 3b51076 commit e5c2b8d
Showing 6 changed files with 404 additions and 0 deletions.
103 changes: 103 additions & 0 deletions test/quantization/test_quantized_op.py
@@ -2734,6 +2734,109 @@ def test_qrnncell(self, num_batches, input_size, hidden_size, per_channel_quant)
self.assertEqual(result_ref[0], result_dynamic[0], msg="torch.quantized_rnncell results are off")


"""Tests recurrent layers."""
class TestPTQRNN(TestCase):
def _snr(self, x, x_hat):
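        """Computes the signal power, the noise power, and the SNR in dB
        between `x` and its (possibly quantized) reconstruction `x_hat`.
        Recurses into lists/tuples and returns a matching nested structure
        of (signal, noise, snr_db) tuples.
        """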
if isinstance(x, (list, tuple)):
            assert len(x) == len(x_hat)
res = []
for idx in range(len(x)):
res.append(self._snr(x[idx], x_hat[idx]))
return res
if x_hat.is_quantized:
x_hat = x_hat.dequantize()
noise = (x - x_hat).square().mean().item()
        if noise == 0:
            # Exact reconstruction: report zero noise and an infinite SNR.
            return x.square().mean().item(), 0.0, float('inf')
signal = x.square().mean().item()
snr = signal / noise
snr_db = 10 * np.log10(snr)
return signal, noise, snr_db

@override_qengines
def test_lstm(self):
qengine = torch.backends.quantized.engine

batch_size = 4
seq_len = 8
input_size = 12

hidden_size = 8
num_layers = 2

        dropout = 0  # Dropout is not yet supported in the quantizable LSTM

Bias = [False, True]
Batch_first = [False, True]
Bidirectional = [False, True]

dtype = np.uint8
qtype = torch.quint8

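        # Swap the float LSTM for the observed quantizable LSTM during
        # prepare() via `prepare_custom_config_dict`.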
custom_module_config = {
'float_to_observed_custom_module_class': {
torch.nn.LSTM: torch.nn.quantizable.LSTM
}
}

        # Quantize-dequantize the random input so that the float and the
        # quantized models see exactly the same values.
        x = np.random.randn(seq_len, batch_size, input_size)
        scale, zero_point = _calculate_dynamic_qparams(x, dtype=dtype)
        x = torch.from_numpy(x).to(torch.float)
        qx = torch.quantize_per_tensor(x, scale=scale, zero_point=zero_point,
                                       dtype=qtype)
        x = qx.dequantize()

with torch.no_grad():
for bias, batch_first, bidirectional in itertools.product(
Bias, Batch_first, Bidirectional):
                # Assume an SNR of 12 dB is sufficient for functional
                # equivalence. Without the bias the linear layers perform
                # poorly, so relax the thresholds.
                min_power = 12 if bias else 6
                max_mse = 1e-6 if bias else 1e-1

if batch_first:
x = x.reshape(batch_size, seq_len, input_size)
qx = qx.reshape(batch_size, seq_len, input_size)
else:
x = x.reshape(seq_len, batch_size, input_size)
qx = qx.reshape(seq_len, batch_size, input_size)

lstm = torch.nn.Sequential(
torch.nn.LSTM(input_size, hidden_size,
num_layers=num_layers,
bias=bias, batch_first=batch_first,
dropout=dropout,
bidirectional=bidirectional))
lstm.eval()
y_ref = lstm(x)

# Prepare
lstm.qconfig = torch.quantization.get_default_qconfig(qengine)
lstm_prepared = torch.quantization.prepare(lstm,
prepare_custom_config_dict=custom_module_config)
self.assertTrue(hasattr(lstm_prepared[0], 'layers'))
self.assertEqual(num_layers, len(lstm_prepared[0].layers))

# Calibrate
y = lstm_prepared(x)
self.assertEqual(y_ref, y)

# Quantize
lstm_quantized = torch.quantization.convert(lstm_prepared)
qy = lstm_quantized(qx)

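                # `y` is `(output, (h, c))`, so `_snr` returns a nested
                # structure; flatten it to `[snr_out, snr_h, snr_c]`.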
snr = self._snr(y, qy)
snr = [snr[0]] + snr[1]

for signal, mse, power in snr:
self.assertTrue(
power > min_power or mse < max_mse,
msg=(f"Error is too high: SNR(dB): {power}, "
f"Signal: {signal}, MSE: {mse}"))


class TestQuantizedLinear(unittest.TestCase):
"""Tests the correctness of the quantized linear and linear_relu op."""
@given(batch_size=st.integers(1, 4),
1 change: 1 addition & 0 deletions test/test_quantization.py
@@ -14,6 +14,7 @@
from quantization.test_quantized_op import TestComparatorOps # noqa: F401
from quantization.test_quantized_op import TestPadding # noqa: F401
from quantization.test_quantized_op import TestQuantizedEmbeddingOps # noqa: F401
from quantization.test_quantized_op import TestPTQRNN # noqa: F401

# Quantized Functional
from quantization.test_quantized_functional import TestQuantizedFunctional # noqa: F401
1 change: 1 addition & 0 deletions torch/__init__.py
@@ -574,6 +574,7 @@ def _assert(condition, message):
import torch.futures
import torch.nn
import torch.nn.intrinsic
import torch.nn.quantizable
import torch.nn.quantized
import torch.optim
import torch.optim._multi_tensor
1 change: 1 addition & 0 deletions torch/nn/quantizable/__init__.py
@@ -0,0 +1 @@
from .modules import *
7 changes: 7 additions & 0 deletions torch/nn/quantizable/modules/__init__.py
@@ -0,0 +1,7 @@
from .rnn import LSTM
from .rnn import LSTMCell

__all__ = [
'LSTM',
'LSTMCell',
]
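
Taken together, these pieces wire the new `torch.nn.quantizable.LSTM` into the eager-mode prepare/calibrate/convert workflow. Below is a minimal sketch of that workflow, distilled from the test in this commit; the qengine, tensor shapes, quantization parameters, and calibration data are illustrative assumptions, not part of the change:

import torch

# A float model containing an LSTM (shapes are illustrative).
model = torch.nn.Sequential(
    torch.nn.LSTM(input_size=12, hidden_size=8, num_layers=2))
model.eval()

# Map the float LSTM onto the observed quantizable LSTM during prepare().
custom_module_config = {
    'float_to_observed_custom_module_class': {
        torch.nn.LSTM: torch.nn.quantizable.LSTM
    }
}
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
prepared = torch.quantization.prepare(
    model, prepare_custom_config_dict=custom_module_config)

# Calibrate with representative data: (seq_len, batch, input_size).
prepared(torch.randn(8, 4, 12))

# Convert; the quantized model consumes quantized inputs (the scale and
# zero_point below are made up for the example).
quantized = torch.quantization.convert(prepared)
qx = torch.quantize_per_tensor(torch.randn(8, 4, 12),
                               scale=0.1, zero_point=128,
                               dtype=torch.quint8)
qy = quantized(qx)  # -> (output, (h_n, c_n)), as with torch.nn.LSTM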
