Update on "[quant] Quantizable LSTM"
- Introduces the `torch.nn.quantizable` namespace
- Adds the `torch.nn.quantizable.LSTM` module

The point of the `quantizable` namespace is to separate the purely quantized modules from the modules that can be quantized through the normal quantization flow but do not use the quantized kernels explicitly.
That means the quantizable modules are functionally and numerically equivalent to the FP ones and can be used in place of the FP ones without any loss of accuracy.
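
For illustration, here is a minimal sketch of that equivalence check using the new `LSTMCell`; the qconfig choice and the tolerance are assumptions made for the example, not part of this diff:

```
import torch
import torch.nn.quantizable as nnqa

float_cell = torch.nn.LSTMCell(10, 20)
# `from_float` asserts that the source module carries a qconfig.
float_cell.qconfig = torch.quantization.default_qconfig
observed_cell = nnqa.LSTMCell.from_float(float_cell)

x = torch.randn(3, 10)
hx, cx = torch.randn(3, 20), torch.randn(3, 20)
h_ref, c_ref = float_cell(x, (hx, cx))
h_obs, c_obs = observed_cell(x, (hx, cx))
# Same math, just expressed through explicit Linear gates.
assert torch.allclose(h_ref, h_obs, atol=1e-5)
assert torch.allclose(c_ref, c_obs, atol=1e-5)
```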

The main difference between `torch.nn.LSTM` and `torch.nn.quantizable.LSTM` is that the former does not support observation of its internal linear operations, because all of the computation happens inside the `aten` namespace.
The `torch.nn.quantizable.LSTM`, however, uses explicit linear layers that can be observed for further quantization.
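
As a rough sketch of what that observation step might look like through the eager-mode custom-module flow (the wrapper module and the explicit config-dict key below are assumptions for the example, not something this diff adds):

```
import torch
import torch.nn.quantizable as nnqa

class Wrapper(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = torch.nn.LSTM(10, 20, num_layers=2)

    def forward(self, x, hidden=None):
        return self.lstm(x, hidden)

float_model = Wrapper()
float_model.qconfig = torch.quantization.default_qconfig

# Ask `prepare` to swap nn.LSTM for the observable nnqa.LSTM.
observed_model = torch.quantization.prepare(
    float_model,
    prepare_custom_config_dict={
        "float_to_observed_custom_module_class": {torch.nn.LSTM: nnqa.LSTM}
    })

# Calibration runs now record statistics on the explicit linear gates.
observed_model(torch.randn(5, 3, 10))
```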

Test Plan:

```
python test/test_quantization.py TestQuantizedOps.test_custom_module_lstm
```

Differential Revision: [D25663870](https://our.internmc.facebook.com/intern/diff/D25663870)

[ghstack-poisoned]
z-a-f committed Dec 30, 2020
2 parents c2f989e + bdd4542 commit 8d85aeb
Showing 1 changed file with 142 additions and 33 deletions.
175 changes: 142 additions & 33 deletions torch/nn/quantizable/modules/rnn.py
@@ -11,9 +11,25 @@
"""

class LSTMCell(torch.nn.Module):
r"""A quantizable long short-term memory (LSTM) cell.
For the description and the argument types, please refer to :class:`~torch.nn.LSTMCell`.
Examples::
>>> import torch.nn.quantizable as nnqa
>>> rnn = nnqa.LSTMCell(10, 20)
>>> input = torch.randn(6, 3, 10)
>>> hx = torch.randn(3, 20)
>>> cx = torch.randn(3, 20)
>>> output = []
>>> for i in range(6):
        hx, cx = rnn(input[i], (hx, cx))
        output.append(hx)
"""
_FLOAT_MODULE = torch.nn.LSTMCell

def __init__(self, input_dim, hidden_dim, bias=True):
def __init__(self, input_dim: int, hidden_dim: int, bias: bool = True):
super().__init__()
self.input_size = input_dim
self.hidden_size = hidden_dim
@@ -66,6 +82,12 @@ def _get_name(self):

@classmethod
def from_params(cls, wi, wh, bi=None, bh=None):
"""Uses the weights and biases to create a new LSTM cell.
Args:
wi, wh: Weights for the input and hidden layers
bi, bh: Biases for the input and hidden layers
"""
assert (bi is None) == (bh is None) # Either both None or both have values
input_size = wi.shape[1]
hidden_size = wh.shape[1]
@@ -81,6 +103,7 @@ def from_params(cls, wi, wh, bi=None, bh=None):

@classmethod
def from_float(cls, other):
assert type(other) == cls._FLOAT_MODULE
assert hasattr(other, 'qconfig'), "The float module must have 'qconfig'"
observed = cls.from_params(other.weight_ih, other.weight_hh,
other.bias_ih, other.bias_hh)
@@ -91,15 +114,20 @@ def from_float(cls, other):


class _LSTMSingleLayer(torch.nn.Module):
def __init__(self, input_dim, hidden_dim, bias=True):
r"""A single one-directional LSTM layer.
The difference between a layer and a cell is that the layer can process a
sequence, while the cell only expects an instantaneous value.
"""
def __init__(self, input_dim: int, hidden_dim: int, bias: bool = True):
super().__init__()
self.cell = LSTMCell(input_dim, hidden_dim, bias=bias)

def forward(self, x, hidden=None):
def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None):
result = []
for xx in x:
hidden = self.cell(xx, hidden)
result.append(hidden[0])
result.append(hidden[0]) # type: ignore
result_tensor = torch.stack(result, 0)
return result_tensor, hidden

@@ -112,30 +140,34 @@ def from_params(cls, *args, **kwargs):


class _LSTMLayer(torch.nn.Module):
def __init__(self, input_dim, hidden_dim, bias=True, batch_first=False,
bidirectional=False):
r"""A single bi-directional LSTM layer."""
def __init__(self, input_dim: int, hidden_dim: int, bias: bool = True,
batch_first: bool = False, bidirectional: bool = False):
super().__init__()
self.batch_first = batch_first
self.bidirectional = bidirectional
self.layer_fw = _LSTMSingleLayer(input_dim, hidden_dim, bias=bias)
if self.bidirectional:
self.layer_bw = _LSTMSingleLayer(input_dim, hidden_dim, bias=bias)

def forward(self, x, hidden=None):
def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None):
if self.batch_first:
x = x.transpose(0, 1)
        if hidden is None:
            hx_fw, cx_fw = (None, None)
        else:
            hx_fw, cx_fw = hidden
        if self.bidirectional:
            if hx_fw is None:
                hx_bw = None
            else:
                hx_bw = hx_fw[1]
                hx_fw = hx_fw[0]
            if cx_fw is None:
                cx_bw = None
            else:
                cx_bw = cx_fw[1]
                cx_fw = cx_fw[0]
hidden_bw = hx_bw, cx_bw
hidden_fw = hx_fw, cx_fw
result_fw, hidden_fw = self.layer_fw(x, hidden_fw)
@@ -146,11 +178,11 @@ def forward(self, x, hidden=None):
result_bw = result_bw.flip(0)

result = torch.cat([result_fw, result_bw], result_fw.dim() - 1)
h = torch.stack([hidden_fw[0], hidden_bw[0]], 0)
c = torch.stack([hidden_fw[1], hidden_bw[1]], 0)
h = torch.stack([hidden_fw[0], hidden_bw[0]], 0) # type: ignore
c = torch.stack([hidden_fw[1], hidden_bw[1]], 0) # type: ignore
else:
result = result_fw
h, c = hidden_fw
h, c = hidden_fw # type: ignore

if self.batch_first:
result.transpose_(0, 1)
@@ -159,6 +191,11 @@ def _get_name(self):

@classmethod
def from_float(cls, other, layer_idx=0, qconfig=None, **kwargs):
r"""
There is no FP equivalent of this class. This function is here just to
mimic the behavior of `prepare` within the `torch.quantization` flow.
"""
assert hasattr(other, 'qconfig') or (qconfig is not None)

input_size = kwargs.get('input_size', other.input_size)
@@ -184,13 +221,79 @@ def from_float(cls, other, layer_idx=0, qconfig=None, **kwargs):
layer.layer_bw = _LSTMSingleLayer.from_params(wi, wh, bi, bh)
return layer

# Getters for the weights and biases
# Note that jit currently doesn't support `property`, so if you need to
# access the weights/biases you would need to navigate manually to the
# `layer_fw.cell.igates.*`: https://github.com/pytorch/pytorch/issues/37883
@property
def weight_ih(self):
return self.layer_fw.cell.igates.weight

@property
def weight_hh(self):
return self.layer_fw.cell.hgates.weight

@property
def bias_ih(self):
return self.layer_fw.cell.igates.bias

@property
def bias_hh(self):
return self.layer_fw.cell.hgates.bias

@property
def weight_ih_reverse(self):
assert self.bidirectional, 'There is no reverse path in the non-bidirectional layer'
return self.layer_bw.cell.igates.weight

@property
def weight_hh_reverse(self):
assert self.bidirectional, 'There is no reverse path in the non-bidirectional layer'
return self.layer_bw.cell.hgates.weight

@property
def bias_ih_reverse(self):
assert self.bidirectional, 'There is no reverse path in the non-bidirectional layer'
return self.layer_bw.cell.igates.bias

@property
def bias_hh_reverse(self):
assert self.bidirectional, 'There is no reverse path in the non-bidirectional layer'
return self.layer_bw.cell.hgates.bias


class LSTM(torch.nn.Module):
r"""A quantizable long short-term memory (LSTM).
For the description and the argument types, please refer to :class:`~torch.nn.LSTM`.
Attributes:
layers : instances of the `_LSTMLayer`
.. note::
To access the weights and biases, you need to access them per layer.
See examples below.
Examples::
>>> import torch.nn.quantizable as nnqa
>>> rnn = nnqa.LSTM(10, 20, 2)
>>> input = torch.randn(5, 3, 10)
>>> h0 = torch.randn(2, 3, 20)
>>> c0 = torch.randn(2, 3, 20)
>>> output, (hn, cn) = rnn(input, (h0, c0))
>>> # To get the weights:
>>> print(rnn.layers[0].weight_ih)
tensor([[...]])
>>> print(rnn.layers[0].weight_hh)
tensor([[...]])
>>> # Reverse-direction accessors require `bidirectional=True`:
>>> print(rnn.layers[0].weight_ih_reverse)
AssertionError: There is no reverse path in the non-bidirectional layer
"""
_FLOAT_MODULE = torch.nn.LSTM

def __init__(self, input_size: int, hidden_size: int,
num_layers: int = 1, bias: bool = True, batch_first: bool = False,
dropout: float = 0., bidirectional: bool = False):
num_layers: int = 1, bias: bool = True,
batch_first: bool = False, dropout: float = 0.,
bidirectional: bool = False):
super().__init__()
self.input_size = input_size
self.hidden_size = hidden_size
@@ -226,7 +329,7 @@ def __init__(self, input_size: int, hidden_size: int,
bidirectional=self.bidirectional))
self.layers = torch.nn.ModuleList(layers)

def forward(self, x, hidden=None):
def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None):
if self.batch_first:
x = x.transpose(0, 1)

@@ -240,24 +343,30 @@
if x.is_quantized:
zeros = torch.quantize_per_tensor(zeros, scale=1.0,
zero_point=0, dtype=x.dtype)
hidden = [(zeros, zeros) for _ in range(self.num_layers)]
elif isinstance(hidden[0], Tensor):
hx = hidden[0].reshape(self.num_layers, num_directions,
max_batch_size, self.hidden_size).unbind(0)
cx = hidden[1].reshape(self.num_layers, num_directions,
max_batch_size, self.hidden_size).unbind(0)
hidden = []
for idx in range(self.num_layers):
hidden.append((hx[idx].squeeze_(0), cx[idx].squeeze_(0)))
hxcx = [(zeros, zeros) for _ in range(self.num_layers)]
else:
hidden_non_opt = torch.jit._unwrap_optional(hidden)
if isinstance(hidden_non_opt[0], Tensor):
hx = hidden_non_opt[0].reshape(self.num_layers, num_directions,
max_batch_size,
self.hidden_size).unbind(0)
cx = hidden_non_opt[1].reshape(self.num_layers, num_directions,
max_batch_size,
self.hidden_size).unbind(0)
hxcx = []
for idx in range(self.num_layers):
hxcx.append((hx[idx].squeeze_(0), cx[idx].squeeze_(0)))
else:
hxcx = hidden_non_opt

for idx in range(self.num_layers):
x, hidden[idx] = self.layers[idx](x, hidden[idx])
x, hxcx[idx] = self.layers[idx](x, hxcx[idx])

hx_list = []
cx_list = []
for idx in range(self.num_layers):
hx_list.append(hidden[idx][0])
cx_list.append(hidden[idx][1])
hx_list.append(hxcx[idx][0])
cx_list.append(hxcx[idx][1])
hx_tensor = torch.stack(hx_list)
cx_tensor = torch.stack(cx_list)


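Given the note in the diff that TorchScript does not support `property` (pytorch/pytorch#37883), the per-layer accessors and the underlying submodule path reach the same tensors. A small sketch of the two access routes (an illustration, not part of the commit):

```
import torch
import torch.nn.quantizable as nnqa

lstm = nnqa.LSTM(10, 20, num_layers=2)

# Eager mode: the read-only properties on each layer.
w_ih = lstm.layers[0].weight_ih

# The direct submodule path, which is also what scripted code has to use.
w_ih_direct = lstm.layers[0].layer_fw.cell.igates.weight
assert w_ih is w_ih_direct

# Reverse-direction accessors assert on non-bidirectional models:
# lstm.layers[0].weight_ih_reverse  # AssertionError
```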