diff --git a/test/quantization/test_quantized_op.py b/test/quantization/test_quantized_op.py
index f1e52fc38d32..2249360d8bd5 100644
--- a/test/quantization/test_quantized_op.py
+++ b/test/quantization/test_quantized_op.py
@@ -2734,6 +2734,109 @@ def test_qrnncell(self, num_batches, input_size, hidden_size, per_channel_quant)
         self.assertEqual(result_ref[0], result_dynamic[0],
                          msg="torch.quantized_rnncell results are off")
 
+class TestPTQRNN(TestCase):
+    """Tests recurrent layers."""
+    def _snr(self, x, x_hat):
+        if isinstance(x, (list, tuple)):
+            assert len(x) == len(x_hat)
+            res = []
+            for idx in range(len(x)):
+                res.append(self._snr(x[idx], x_hat[idx]))
+            return res
+        if x_hat.is_quantized:
+            x_hat = x_hat.dequantize()
+        noise = (x - x_hat).square().mean().item()
+        if noise == 0:
+            return 0.0, float('inf'), float('inf')
+        signal = x.square().mean().item()
+        snr = signal / noise
+        snr_db = 10 * np.log10(snr)
+        return signal, noise, snr_db
+
+    @override_qengines
+    def test_lstm(self):
+        qengine = torch.backends.quantized.engine
+
+        batch_size = 4
+        seq_len = 8
+        input_size = 12
+
+        hidden_size = 8
+        num_layers = 2
+
+        dropout = 0  # This is not supported
+
+        Bias = [False, True]
+        Batch_first = [False, True]
+        Bidirectional = [False, True]
+
+        dtype = np.uint8
+        qtype = torch.quint8
+
+        custom_module_config = {
+            'float_to_observed_custom_module_class': {
+                torch.nn.LSTM: torch.nn.quantizable.LSTM
+            }
+        }
+
+        # Quantize/dequantize the random input so both paths see the same values
+        x = np.random.randn(seq_len, batch_size, input_size)
+        scale, zero_point = _calculate_dynamic_qparams(x, dtype=dtype)
+        qx = _quantize(x, scale=scale, zero_point=zero_point, dtype=dtype)
+        x = torch.from_numpy(x).to(torch.float)
+        qx = torch.quantize_per_tensor(x, scale=scale, zero_point=zero_point,
+                                       dtype=qtype)
+        x = qx.dequantize()
+
+        with torch.no_grad():
+            for bias, batch_first, bidirectional in itertools.product(
+                    Bias, Batch_first, Bidirectional):
+                # Assume 12dB is sufficient for functional equivalence
+                # Without the bias, the linear layers perform poorly
+                min_power = 12 if bias else 6
+                max_mse = 1e-6 if bias else 1e-1
+
+                if batch_first:
+                    x = x.reshape(batch_size, seq_len, input_size)
+                    qx = qx.reshape(batch_size, seq_len, input_size)
+                else:
+                    x = x.reshape(seq_len, batch_size, input_size)
+                    qx = qx.reshape(seq_len, batch_size, input_size)
+
+                lstm = torch.nn.Sequential(
+                    torch.nn.LSTM(input_size, hidden_size,
+                                  num_layers=num_layers,
+                                  bias=bias, batch_first=batch_first,
+                                  dropout=dropout,
+                                  bidirectional=bidirectional))
+                lstm.eval()
+                y_ref = lstm(x)
+
+                # Prepare
+                lstm.qconfig = torch.quantization.get_default_qconfig(qengine)
+                lstm_prepared = torch.quantization.prepare(lstm,
+                    prepare_custom_config_dict=custom_module_config)
+                self.assertTrue(hasattr(lstm_prepared[0], 'layers'))
+                self.assertEqual(num_layers, len(lstm_prepared[0].layers))
+
+                # Calibrate
+                y = lstm_prepared(x)
+                self.assertEqual(y_ref, y)
+
+                # Quantize
+                lstm_quantized = torch.quantization.convert(lstm_prepared)
+                qy = lstm_quantized(qx)
+
+                snr = self._snr(y, qy)
+                snr = [snr[0]] + snr[1]
+
+                for signal, mse, power in snr:
+                    self.assertTrue(
+                        power > min_power or mse < max_mse,
+                        msg=(f"Error is too high: SNR(dB): {power}, "
+                             f"Signal: {signal}, MSE: {mse}"))
+
+
 class TestQuantizedLinear(unittest.TestCase):
     """Tests the correctness of the quantized linear and linear_relu op."""
     @given(batch_size=st.integers(1, 4),
diff --git a/test/test_quantization.py b/test/test_quantization.py
index 08399206257d..2b2adf7a07c2
--- a/test/test_quantization.py
+++ b/test/test_quantization.py
@@ -14,6 +14,7 @@
 from quantization.test_quantized_op import TestComparatorOps  # noqa: F401
 from quantization.test_quantized_op import TestPadding  # noqa: F401
 from quantization.test_quantized_op import TestQuantizedEmbeddingOps  # noqa: F401
+from quantization.test_quantized_op import TestPTQRNN  # noqa: F401
 
 # Quantized Functional
 from quantization.test_quantized_functional import TestQuantizedFunctional  # noqa: F401
diff --git a/torch/__init__.py b/torch/__init__.py
index 04955623ab2a..9ae1010a3ba8 100644
--- a/torch/__init__.py
+++ b/torch/__init__.py
@@ -574,6 +574,7 @@ def _assert(condition, message):
 import torch.futures
 import torch.nn
 import torch.nn.intrinsic
+import torch.nn.quantizable
 import torch.nn.quantized
 import torch.optim
 import torch.optim._multi_tensor
diff --git a/torch/nn/quantizable/__init__.py b/torch/nn/quantizable/__init__.py
new file mode 100644
index 000000000000..270dcebaa5f4
--- /dev/null
+++ b/torch/nn/quantizable/__init__.py
@@ -0,0 +1 @@
+from .modules import *
diff --git a/torch/nn/quantizable/modules/__init__.py b/torch/nn/quantizable/modules/__init__.py
new file mode 100644
index 000000000000..2e2f4de995f8
--- /dev/null
+++ b/torch/nn/quantizable/modules/__init__.py
@@ -0,0 +1,7 @@
+from .rnn import LSTM
+from .rnn import LSTMCell
+
+__all__ = [
+    'LSTM',
+    'LSTMCell',
+]
diff --git a/torch/nn/quantizable/modules/rnn.py b/torch/nn/quantizable/modules/rnn.py
new file mode 100644
index 000000000000..29b5de42b67b
--- /dev/null
+++ b/torch/nn/quantizable/modules/rnn.py
@@ -0,0 +1,292 @@
+import numbers
+import warnings
+from typing import Optional, Tuple
+
+import torch
+from torch import Tensor
+
+"""
+We recreate all the RNN modules here because we need them decomposed into
+their building blocks so that the intermediate operations can be observed.
+"""
+
+class LSTMCell(torch.nn.Module):
+    _FLOAT_MODULE = torch.nn.LSTMCell
+    # There are multiple outputs in the forward -- we don't want to observe!
+    do_not_observe: bool = True
+
+    def __init__(self, input_dim, hidden_dim, bias=True):
+        super().__init__()
+        self.input_size = input_dim
+        self.hidden_size = hidden_dim
+        self.bias = bias
+
+        self.igates = torch.nn.Linear(input_dim, 4 * hidden_dim, bias=bias)
+        self.hgates = torch.nn.Linear(hidden_dim, 4 * hidden_dim, bias=bias)
+        self.gates = torch.nn.quantized.FloatFunctional()
+
+        self.fgate_cx = torch.nn.quantized.FloatFunctional()
+        self.igate_cgate = torch.nn.quantized.FloatFunctional()
+        self.fgate_cx_igate_cgate = torch.nn.quantized.FloatFunctional()
+
+        self.ogate_cy = torch.nn.quantized.FloatFunctional()
+
+    def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None) -> Tuple[Tensor, Tensor]:
+        if hidden is None or hidden == (None, None):
+            hidden = self.initialize_hidden(x.shape[0], x.is_quantized)
+        hx, cx = hidden
+
+        igates = self.igates(x)
+        hgates = self.hgates(hx)
+        gates = self.gates.add(igates, hgates)
+
+        input_gate, forget_gate, cell_gate, out_gate = gates.chunk(4, 1)
+
+        input_gate = torch.sigmoid(input_gate)
+        forget_gate = torch.sigmoid(forget_gate)
+        cell_gate = torch.tanh(cell_gate)
+        out_gate = torch.sigmoid(out_gate)
+
+        fgate_cx = self.fgate_cx.mul(forget_gate, cx)
+        igate_cgate = self.igate_cgate.mul(input_gate, cell_gate)
+        fgate_cx_igate_cgate = self.fgate_cx_igate_cgate.add(fgate_cx, igate_cgate)
+        cy = fgate_cx_igate_cgate
+
+        tanh_cy = torch.tanh(cy)
+        hy = self.ogate_cy.mul(out_gate, tanh_cy)
+        return hy, cy
+
+    def initialize_hidden(self, batch_size: int, is_quantized: bool = False) -> Tuple[Tensor, Tensor]:
+        h, c = torch.zeros((batch_size, self.hidden_size)), torch.zeros((batch_size, self.hidden_size))
+        if is_quantized:
+            h = torch.quantize_per_tensor(h, scale=1.0, zero_point=0, dtype=torch.quint8)
+            c = torch.quantize_per_tensor(c, scale=1.0, zero_point=0, dtype=torch.quint8)
+        return h, c
+
+    @classmethod
+    def from_params(cls, wi, wh, bi=None, bh=None):
+        assert (bi is None) == (bh is None)  # Either both None or both have values
+        input_size = wi.shape[1]
+        hidden_size = wh.shape[1]
+        cell = cls(input_dim=input_size, hidden_dim=hidden_size,
+                   bias=(bi is not None))
+        cell.igates.weight = torch.nn.Parameter(wi)
+        if bi is not None:
+            cell.igates.bias = torch.nn.Parameter(bi)
+        cell.hgates.weight = torch.nn.Parameter(wh)
+        if bh is not None:
+            cell.hgates.bias = torch.nn.Parameter(bh)
+        return cell
+
+    @classmethod
+    def from_float(cls, other):
+        assert hasattr(other, 'qconfig'), "The float module must have 'qconfig'"
+        observed = cls.from_params(other.weight_ih, other.weight_hh,
+                                   other.bias_ih, other.bias_hh)
+        observed.qconfig = other.qconfig
+        observed.igates.qconfig = other.qconfig
+        observed.hgates.qconfig = other.qconfig
+        return observed
+
+
+class _LSTMSingleLayer(torch.nn.Module):
+    def __init__(self, input_dim, hidden_dim, bias=True):
+        super().__init__()
+        self.cell = LSTMCell(input_dim, hidden_dim, bias=bias)
+
+    def forward(self, x, hidden=None):
+        result = []
+        for xx in x:
+            hidden = self.cell(xx, hidden)
+            result.append(hidden[0])
+        result = torch.stack(result, 0)
+        return result, hidden
+
+    @classmethod
+    def from_params(cls, *args, **kwargs):
+        cell = LSTMCell.from_params(*args, **kwargs)
+        layer = cls(cell.input_size, cell.hidden_size, cell.bias)
+        layer.cell = cell
+        return layer
+
+
+class _LSTMLayer(torch.nn.Module):
+    def __init__(self, input_dim, hidden_dim, bias=True, batch_first=False,
+                 bidirectional=False):
+        super().__init__()
+        self.batch_first = batch_first
+        self.bidirectional = bidirectional
+        self.layer_fw = _LSTMSingleLayer(input_dim, hidden_dim, bias=bias)
+        if self.bidirectional:
+            self.layer_bw = _LSTMSingleLayer(input_dim, hidden_dim, bias=bias)
+
+    def forward(self, x, hidden=None):
+        if self.batch_first:
+            x = x.transpose(0, 1)
+        if hidden is None:
+            hidden = (None, None)
+        hx_fw, cx_fw = hidden
+        if self.bidirectional:
+            if hx_fw is None:
+                hx_fw = [None, None]
+            if cx_fw is None:
+                cx_fw = [None, None]
+            hx_bw = hx_fw[1]
+            hx_fw = hx_fw[0]
+            cx_bw = cx_fw[1]
+            cx_fw = cx_fw[0]
+            hidden_bw = hx_bw, cx_bw
+        hidden_fw = hx_fw, cx_fw
+        result_fw, hidden_fw = self.layer_fw(x, hidden_fw)
+
+        if self.bidirectional:
+            x_reversed = x.flip(0)
+            result_bw, hidden_bw = self.layer_bw(x_reversed, hidden_bw)
+            result_bw = result_bw.flip(0)
+
+            result = torch.cat([result_fw, result_bw], result_fw.dim() - 1)
+            h = torch.stack([hidden_fw[0], hidden_bw[0]], 0)
+            c = torch.stack([hidden_fw[1], hidden_bw[1]], 0)
+        else:
+            result = result_fw
+            h, c = hidden_fw
+
+        if self.batch_first:
+            result.transpose_(0, 1)
+
+        return result, (h, c)
+
+    @classmethod
+    def from_float(cls, other, layer_idx=0, qconfig=None, **kwargs):
+        assert hasattr(other, 'qconfig') or (qconfig is not None)
+
+        input_size = kwargs.get('input_size', other.input_size)
+        hidden_size = kwargs.get('hidden_size', other.hidden_size)
+        bias = kwargs.get('bias', other.bias)
+        batch_first = kwargs.get('batch_first', other.batch_first)
+        bidirectional = kwargs.get('bidirectional', other.bidirectional)
+
+        layer = cls(input_size, hidden_size, bias, batch_first, bidirectional)
+        layer.qconfig = getattr(other, 'qconfig', qconfig)
+        wi = getattr(other, f'weight_ih_l{layer_idx}')
+        wh = getattr(other, f'weight_hh_l{layer_idx}')
+        bi = getattr(other, f'bias_ih_l{layer_idx}', None)
+        bh = getattr(other, f'bias_hh_l{layer_idx}', None)
+
+        layer.layer_fw = _LSTMSingleLayer.from_params(wi, wh, bi, bh)
+
+        if other.bidirectional:
+            wi = getattr(other, f'weight_ih_l{layer_idx}_reverse')
+            wh = getattr(other, f'weight_hh_l{layer_idx}_reverse')
+            bi = getattr(other, f'bias_ih_l{layer_idx}_reverse', None)
+            bh = getattr(other, f'bias_hh_l{layer_idx}_reverse', None)
+            layer.layer_bw = _LSTMSingleLayer.from_params(wi, wh, bi, bh)
+        return layer
+
+
+class LSTM(torch.nn.Module):
+    _FLOAT_MODULE = torch.nn.LSTM
+    # There are multiple outputs in the forward -- we don't want to observe!
+    do_not_observe: bool = True
+
+    def __init__(self, input_size: int, hidden_size: int,
+                 num_layers: int = 1, bias: bool = True, batch_first: bool = False,
+                 dropout: float = 0., bidirectional: bool = False):
+        super().__init__()
+        self.input_size = input_size
+        self.hidden_size = hidden_size
+        self.num_layers = num_layers
+        self.bias = bias
+        self.batch_first = batch_first
+        self.dropout = float(dropout)
+        self.bidirectional = bidirectional
+        self.training = False  # We don't want to train using this module
+        num_directions = 2 if bidirectional else 1
+
+        if not isinstance(dropout, numbers.Number) or not 0 <= dropout <= 1 or \
+                isinstance(dropout, bool):
+            raise ValueError("dropout should be a number in range [0, 1] "
+                             "representing the probability of an element being "
+                             "zeroed")
+        if dropout > 0:
+            warnings.warn("dropout option for quantizable LSTM is ignored. "
+                          "If you are training, please, use nn.LSTM version "
+                          "followed by `prepare` step.")
+            if num_layers == 1:
+                warnings.warn("dropout option adds dropout after all but last "
+                              "recurrent layer, so non-zero dropout expects "
+                              "num_layers greater than 1, but got dropout={} "
+                              "and num_layers={}".format(dropout, num_layers))
+
+        layers = [_LSTMLayer(self.input_size, self.hidden_size,
+                             self.bias, batch_first=False,
+                             bidirectional=self.bidirectional)]
+        for layer in range(1, num_layers):
+            layers.append(_LSTMLayer(self.hidden_size, self.hidden_size,
+                                     self.bias, batch_first=False,
+                                     bidirectional=self.bidirectional))
+        self.layers = torch.nn.ModuleList(layers)
+
+    def forward(self, x, hidden=None):
+        if self.batch_first:
+            x = x.transpose(0, 1)
+
+        max_batch_size = x.size(1)
+        num_directions = 2 if self.bidirectional else 1
+        if hidden is None:
+            zeros = torch.zeros(num_directions, max_batch_size,
+                                self.hidden_size, dtype=torch.float,
+                                device=x.device)
+            zeros.squeeze_(0)
+            if x.is_quantized:
+                zeros = torch.quantize_per_tensor(zeros, scale=1.0,
+                                                  zero_point=0, dtype=x.dtype)
+            hidden = [(zeros, zeros) for _ in range(self.num_layers)]
+        elif isinstance(hidden[0], Tensor):
+            hx = hidden[0].reshape(self.num_layers, num_directions,
+                                   max_batch_size, self.hidden_size).unbind(0)
+            cx = hidden[1].reshape(self.num_layers, num_directions,
+                                   max_batch_size, self.hidden_size).unbind(0)
+            hidden = []
+            for idx in range(self.num_layers):
+                hidden.append((hx[idx].squeeze_(0), cx[idx].squeeze_(0)))
+
+        for idx in range(self.num_layers):
+            x, hidden[idx] = self.layers[idx](x, hidden[idx])
+
+        hx = []
+        cx = []
+        for idx in range(self.num_layers):
+            hx.append(hidden[idx][0])
+            cx.append(hidden[idx][1])
+        hx = torch.stack(hx)
+        cx = torch.stack(cx)
+
+        # We are creating another dimension for the bidirectional case;
+        # we need to collapse it
+        hx = hx.reshape(-1, *hx.shape[-2:])
+        cx = cx.reshape(-1, *cx.shape[-2:])
+
+        if self.batch_first:
+            x = x.transpose(0, 1)
+
+        return x, (hx, cx)
+
+    @classmethod
+    def from_float(cls, other, qconfig=None):
+        assert isinstance(other, cls._FLOAT_MODULE)
+        assert (hasattr(other, 'qconfig') or qconfig)
+        observed = cls(other.input_size, other.hidden_size, other.num_layers,
+                       other.bias, other.batch_first, other.dropout,
+                       other.bidirectional)
+        observed.qconfig = getattr(other, 'qconfig', qconfig)
+        for idx in range(other.num_layers):
+            observed.layers[idx] = _LSTMLayer.from_float(other, idx, qconfig,
+                                                         batch_first=False)
+        observed.eval()
+        observed = torch.quantization.prepare(observed, inplace=True)
+        return observed
+
+    def from_observed(self, other):
+        return torch.quantization.convert(self, inplace=False,
+                                          remove_qconfig=True)