[quant] Quantizable LSTM #49671

Closed
83 changes: 82 additions & 1 deletion test/quantization/test_quantized_op.py
@@ -23,7 +23,7 @@
from torch.testing._internal.common_utils import IS_PPC, TEST_WITH_UBSAN, IS_MACOS
from torch.testing._internal.common_quantization import skipIfNoFBGEMM
from torch.testing._internal.common_quantized import _quantize, _dequantize, _calculate_dynamic_qparams, \
-    override_quantized_engine, supported_qengines, override_qengines
+    override_quantized_engine, supported_qengines, override_qengines, _snr
from torch.testing._internal.common_quantized import qengine_is_qnnpack
from torch.quantization import PerChannelMinMaxObserver

@@ -2314,6 +2314,87 @@ def test_advanced_indexing(self):
                torch.quantize_per_tensor(x_fp32_s4, scale, zp, dtype)
            self.assertEqual(x_q_s4, x_fp32_s4_ref)

    @override_qengines
    def test_custom_module_lstm(self):
        qengine = torch.backends.quantized.engine

        batch_size = 4
        seq_len = 8
        input_size = 12

        hidden_size = 8
        num_layers = 2

        dropout = 0  # dropout is not supported in the quantizable LSTM

        Bias = [False, True]
        Batch_first = [False, True]
        Bidirectional = [False, True]

        dtype = np.uint8
        qtype = torch.quint8

        # During prepare, swap the float LSTM for its observed counterpart.
        custom_module_config = {
            'float_to_observed_custom_module_class': {
                torch.nn.LSTM: torch.nn.quantizable.LSTM
            }
        }

        x = np.random.randn(seq_len, batch_size, input_size)
        scale, zero_point = _calculate_dynamic_qparams(x, dtype=dtype)
        x = torch.from_numpy(x).to(torch.float)
        qx = torch.quantize_per_tensor(x, scale=scale, zero_point=zero_point,
                                       dtype=qtype)
        x = qx.dequantize()

        with torch.no_grad():
            for bias, batch_first, bidirectional in itertools.product(
                    Bias, Batch_first, Bidirectional):
                # Assume a 10dB SNR is sufficient for functional equivalence.
                # Without the bias, the linear layers perform poorly, so the
                # thresholds are relaxed.
                min_power = 10 if bias else 5
                max_mse = 5e-6 if bias else 5e-1

                if batch_first:
                    x = x.reshape(batch_size, seq_len, input_size)
                    qx = qx.reshape(batch_size, seq_len, input_size)
                else:
                    x = x.reshape(seq_len, batch_size, input_size)
                    qx = qx.reshape(seq_len, batch_size, input_size)

                lstm = torch.nn.Sequential(
                    torch.nn.LSTM(input_size, hidden_size,
                                  num_layers=num_layers,
                                  bias=bias, batch_first=batch_first,
                                  dropout=dropout,
                                  bidirectional=bidirectional))
                lstm.eval()
                y_ref = lstm(x)

                # Prepare
                lstm.qconfig = torch.quantization.get_default_qconfig(qengine)
                lstm_prepared = torch.quantization.prepare(
                    lstm, prepare_custom_config_dict=custom_module_config)
                self.assertTrue(hasattr(lstm_prepared[0], 'layers'))
                self.assertEqual(num_layers, len(lstm_prepared[0].layers))

                # Calibrate
                y = lstm_prepared(x)
                self.assertEqual(y_ref, y)

                # Quantize
                lstm_quantized = torch.quantization.convert(lstm_prepared)
                qy = lstm_quantized(qx)

                # _snr returns a triple for the output and a list of triples
                # for the hidden states; flatten them into a single list.
                snr = _snr(y, qy)
                snr = [snr[0]] + snr[1]

                for signal, mse, power in snr:
                    self.assertTrue(
                        power > min_power or mse < max_mse,
                        msg=(f"Error is too high: SNR(dB): {power}, "
                             f"Signal: {signal}, MSE: {mse}"))


class TestDynamicQuantizedLinear(TestCase):
    """Tests the correctness of the dynamic quantized linear and linear_relu op."""
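
A note on the acceptance criterion in test_custom_module_lstm: _snr (imported from torch.testing._internal.common_quantized) appears to return a (signal, mse, power) triple for the LSTM output plus a list of such triples for the hidden states, which the test flattens before checking. As a rough sketch of what such a metric computes — an illustrative reimplementation under that assumption, not the actual helper:

import math

import torch

def snr_db(ref: torch.Tensor, approx: torch.Tensor):
    """Return (signal_power, mse, snr_in_dB) for `approx` measured against `ref`."""
    if approx.is_quantized:
        approx = approx.dequantize()
    signal = ref.pow(2).mean().item()
    mse = (ref - approx).pow(2).mean().item()
    # SNR in dB is 10*log10(signal/noise); a perfect match gets infinite SNR.
    power = 10 * math.log10(signal / mse) if mse > 0 else math.inf
    return signal, mse, power

Under this reading, the 10dB gate requires the quantization noise to carry at most a tenth of the signal power, while the max_mse escape hatch covers near-zero signals where the ratio is ill-conditioned.
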
1 change: 1 addition & 0 deletions test/test_quantization.py
@@ -15,6 +15,7 @@
from quantization.test_quantized_op import TestPadding # noqa: F401
from quantization.test_quantized_op import TestQuantizedEmbeddingOps # noqa: F401
from quantization.test_quantized_op import TestDynamicQuantizedRNNOp # noqa: F401

# Quantized Functional
from quantization.test_quantized_functional import TestQuantizedFunctional # noqa: F401

1 change: 1 addition & 0 deletions torch/__init__.py
@@ -574,6 +574,7 @@ def _assert(condition, message):
import torch.futures
import torch.nn
import torch.nn.intrinsic
import torch.nn.quantizable
Contributor:

Should we put this under torch.nn.mo.quantizable, since we are planning to move everything under mo?

Contributor Author:

I was thinking about it -- but currently mo is in the fb namespace. Should I start introducing the new namespace?

import torch.nn.quantized
import torch.optim
import torch.optim._multi_tensor
1 change: 1 addition & 0 deletions torch/nn/quantizable/__init__.py
@@ -0,0 +1 @@
from .modules import *
7 changes: 7 additions & 0 deletions torch/nn/quantizable/modules/__init__.py
@@ -0,0 +1,7 @@
from .rnn import LSTM
from .rnn import LSTMCell

__all__ = [
'LSTM',
'LSTMCell',
]
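
For reference, a minimal end-to-end sketch of the workflow these modules enable, distilled from test_custom_module_lstm above (the 'fbgemm' engine choice, the example shapes, and the example qparams are illustrative assumptions, not part of this PR):

import torch

# A float model wrapping a standard LSTM.
model = torch.nn.Sequential(
    torch.nn.LSTM(input_size=12, hidden_size=8, num_layers=2))
model.eval()

# During prepare, swap nn.LSTM for the observed quantizable LSTM.
custom_config = {
    'float_to_observed_custom_module_class': {
        torch.nn.LSTM: torch.nn.quantizable.LSTM,
    }
}
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
prepared = torch.quantization.prepare(
    model, prepare_custom_config_dict=custom_config)

# Calibrate on representative float inputs, shaped (seq_len, batch, input_size).
x = torch.randn(8, 4, 12)
with torch.no_grad():
    prepared(x)

# Convert, then run on quantized inputs (these qparams are arbitrary examples).
quantized = torch.quantization.convert(prepared)
qx = torch.quantize_per_tensor(x, scale=0.1, zero_point=128,
                               dtype=torch.quint8)
output, (hidden, cell) = quantized(qx)

As in the test, convert is called without an extra custom config, which suggests the observed-to-quantized mapping for the new module is picked up by the default conversion settings.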