In [1]:
%matplotlib inline

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np

In [2]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn import quantized as nnq

## nn.Module implementation

In [3]:
class LSTMCell(nn.Module):
    """Single LSTM cell built from quantization-friendly primitives.

    Computes one step of the standard LSTM recurrence::

        i, f, g, o = chunk(W_ih x + b_ih + W_hh h + b_hh, 4)
        c' = sigmoid(f) * c + sigmoid(i) * tanh(g)
        h' = sigmoid(o) * tanh(c')

    The gate order (input, forget, cell, output) is the same as ``nn.LSTMCell``.
    Elementwise adds/muls go through ``FloatFunctional`` instances so the
    eager-mode quantization flow (``prepare``/``convert``) can observe and
    replace them.

    Args:
        input_dim: number of features in the input ``x``.
        hidden_dim: number of features in the hidden and cell states.
    """
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.input_size = input_dim
        self.hidden_size = hidden_dim

        self.igates = nn.Linear(input_dim, 4 * hidden_dim, bias=True)
        self.hgates = nn.Linear(hidden_dim, 4 * hidden_dim, bias=True)  # Maybe we don't need bias here
        self.gates = nnq.FloatFunctional()

        self.fgate_cx = nnq.FloatFunctional()
        self.igate_cgate = nnq.FloatFunctional()
        self.fgate_cx_igate_cgate = nnq.FloatFunctional()

        self.ogate_cy = nnq.FloatFunctional()

    def forward(self, x, hidden=None):
        """Run one time step.

        Args:
            x: input of shape (batch, input_size); float or quantized.
            hidden: optional ``(hx, cx)`` pair, each (batch, hidden_size).
                Zero states are created when omitted.

        Returns:
            ``(hy, cy)``, each of shape (batch, hidden_size).
        """
        if hidden is None:
            # Allocate the zero state next to the input so non-default devices work.
            hidden = self.initialize_hidden(x.shape[0], x.is_quantized, device=x.device)
        hx, cx = hidden

        igates = self.igates(x)  # (Batch, 4*hiddenSize)
        hgates = self.hgates(hx)  # (Batch, 4*hiddenSize)
        gates = self.gates.add(igates, hgates)  # (Batch, 4*hiddenSize)

        input_gate, forget_gate, cell_gate, out_gate = gates.chunk(4, 1)  # (Batch, hiddenSize) x 4

        input_gate = torch.sigmoid(input_gate)
        forget_gate = torch.sigmoid(forget_gate)
        cell_gate = torch.tanh(cell_gate)
        out_gate = torch.sigmoid(out_gate)

        fgate_cx = self.fgate_cx.mul(forget_gate, cx)  # (Batch, hiddenSize)
        igate_cgate = self.igate_cgate.mul(input_gate, cell_gate)  # (Batch, hiddenSize)
        cy = self.fgate_cx_igate_cgate.add(fgate_cx, igate_cgate)  # (Batch, hiddenSize)

        # torch.tanh rather than the deprecated F.tanh, matching the gate activations above.
        tanh_cy = torch.tanh(cy)
        hy = self.ogate_cy.mul(out_gate, tanh_cy)  # (Batch, hiddenSize)

        return hy, cy  # (Batch, hiddenSize), (Batch, hiddenSize)

    def initialize_hidden(self, batch_size, is_quantized, device=None):
        """Return zero ``(h, c)`` states, each of shape (batch_size, hidden_size).

        Args:
            batch_size: leading dimension of the states.
            is_quantized: when True the zeros are quantized to ``torch.quint8``
                (scale=1.0, zero_point=0) to pair with a quantized input.
            device: device to allocate the states on; ``None`` uses the
                default device (the previous behaviour).
        """
        shape = (batch_size, self.hidden_size)
        h = torch.zeros(shape, device=device)
        c = torch.zeros(shape, device=device)
        if is_quantized:
            h = torch.quantize_per_tensor(h, scale=1.0, zero_point=0, dtype=torch.quint8)
            c = torch.quantize_per_tensor(c, scale=1.0, zero_point=0, dtype=torch.quint8)
        return h, c

In [4]:
# Smoke test: push one random batch through a single LSTMCell.
B, S, iS = 16, 128, 7  # batch size, sequence length (not used in this cell), input size

x = torch.randn(B, iS)
lstm = LSTMCell(iS, 32)

hy, cy = lstm(x)
print(hy.shape, cy.shape)

torch.Size([16, 32]) torch.Size([16, 32])




In [5]:
class LSTMStack(nn.Module):
    """Stack of ``num_layers`` LSTMCells advanced by a single time step.

    The hidden output of layer ``k`` is the input of layer ``k + 1``.

    Args:
        input_size: features of the input to the first layer.
        hidden_sizes: hidden size per layer; a single int is repeated for
            every layer.
        num_layers: number of stacked cells.
    """
    def __init__(self, input_size, hidden_sizes, num_layers):
        super().__init__()
        self.num_layers = num_layers
        if not isinstance(hidden_sizes, (list, tuple)):
            hidden_sizes = [hidden_sizes] * self.num_layers
        assert len(hidden_sizes) == num_layers, (
            f"expected {num_layers} hidden sizes, got {len(hidden_sizes)}")
        self.input_size = input_size
        self.hidden_sizes = hidden_sizes

        # Layer 0 consumes the input; later layers consume the previous layer's hidden state.
        cells = [LSTMCell(input_size, hidden_sizes[0])]
        for idx in range(1, self.num_layers):
            cells.append(LSTMCell(hidden_sizes[idx - 1], hidden_sizes[idx]))
        self.cells = nn.ModuleList(cells)

    def forward(self, x, hidden=None):
        """Advance every layer by one time step.

        Args:
            x: input of shape (batch, input_size).
            hidden: optional ``(hx, cx)`` where each is a sequence (list or
                tuple) holding one state tensor per layer. It is not modified.

        Returns:
            ``(hx, cx)``: new lists with the updated per-layer states.
        """
        if hidden is None:
            hidden = self.initialize_hidden(x.shape[0], x.is_quantized)
        hx_in, cx_in = hidden

        # Build fresh lists instead of assigning into the caller's containers:
        # this avoids mutating/aliasing caller state and also accepts tuples.
        hx, cx = [], []
        for idx, cell in enumerate(self.cells):
            x, cy = cell(x, (hx_in[idx], cx_in[idx]))
            hx.append(x)
            cx.append(cy)
        return hx, cx

    def initialize_hidden(self, batch_size, quantized):
        """Return ``(hx, cx)`` lists of per-layer zero states from each cell."""
        hc = [cell.initialize_hidden(batch_size, quantized) for cell in self.cells]
        hx, cx = zip(*hc)
        return list(hx), list(cx)
        

In [6]:
# Smoke test: a single time step through a 3-layer LSTMStack.
B, S, iS = 16, 128, 7  # batch size, sequence length (not used in this cell), input size
num_layers = 3

x = torch.randn(B, iS)
lstm = LSTMStack(iS, 32, num_layers)

hy, cy = lstm(x)
print(len(hy), len(cy))
print(hy[0].shape, cy[0].shape)

3 3
torch.Size([16, 32]) torch.Size([16, 32])


In [7]:
class LSTM(nn.Module):
    """Multi-layer LSTM that unrolls an ``LSTMStack`` over a sequence.

    The input is sequence-first: ``x`` has shape (seq_len, batch, input_size).

    ``forward`` returns ``(y, hidden)``. ``y`` stacks the top layer's hidden
    state from every time step along dim 0, so its shape is
    (seq_len, batch, hidden_sizes[-1]). ``hidden`` is the final ``(hx, cx)``
    produced by the stack.
    """
    def __init__(self, input_size, hidden_sizes, num_layers):
        super().__init__()
        self.cell_stack = LSTMStack(input_size, hidden_sizes, num_layers)

    def forward(self, x, hidden=None):
        stack = self.cell_stack
        # dim 1 is the batch dimension in the sequence-first layout.
        state = hidden if hidden is not None else stack.initialize_hidden(x.shape[1], x.is_quantized)

        top_outputs = []
        for t in range(x.shape[0]):
            state = stack(x[t], state)
            layer_h, _ = state
            top_outputs.append(layer_h[-1])  # hidden state of the last (top) layer
        return torch.stack(top_outputs, 0), state

In [8]:
# Full sequence: (seq_len, batch, input_size) through the unrolled LSTM.
B, S, iS, num_layers = 16, 128, 7, 3  # batch, seq length, input size, layers

x = torch.randn(S, B, iS)
lstm = LSTM(iS, 32, num_layers)
y, (hy, cy) = lstm(x)

print(x.shape, y.shape)
print(len(hy), len(cy))
print(hy[0].shape, cy[0].shape)

torch.Size([128, 16, 7]) torch.Size([128, 16, 32])
3 3
torch.Size([16, 32]) torch.Size([16, 32])


In [9]:
# Reference check: nn.LSTM with the custom LSTM's weights copied in.
# nn.LSTM uses the same (input, forget, cell, output) gate layout, so the
# sequence outputs must match; with independent random weights the
# comparison would be meaningless.
ref_lstm = nn.LSTM(iS, 32, num_layers)
with torch.no_grad():
    for layer, cell in enumerate(lstm.cell_stack.cells):
        getattr(ref_lstm, f"weight_ih_l{layer}").copy_(cell.igates.weight)
        getattr(ref_lstm, f"bias_ih_l{layer}").copy_(cell.igates.bias)
        getattr(ref_lstm, f"weight_hh_l{layer}").copy_(cell.hgates.weight)
        getattr(ref_lstm, f"bias_hh_l{layer}").copy_(cell.hgates.bias)

y_ref = ref_lstm(x)
print(y_ref[0].shape)
assert torch.allclose(y_ref[0], y, atol=1e-4), "custom LSTM diverges from nn.LSTM"

torch.Size([128, 16, 32])


## Quantized version (eager mode)

In [10]:
import torch.quantization as tq

batch_size = 7
seq_len = 257
input_size = 31
hidden_size = 61
num_layers = 5

x = torch.randn(seq_len, batch_size, input_size)

# Derive the input quantization parameters from the data. The previous fixed
# scale=1e-2, zero_point=128 only represents [-1.28, 1.27], so roughly a fifth
# of standard-normal samples were clipped before reaching the model.
input_observer = tq.MinMaxObserver(dtype=torch.quint8)
input_observer(x)
input_scale, input_zero_point = input_observer.calculate_qparams()
qx = torch.quantize_per_tensor(
    x,
    scale=float(input_scale.item()),
    zero_point=int(input_zero_point.item()),
    dtype=torch.quint8,
)

In [11]:
lstm = LSTM(input_size, hidden_size, num_layers)
lstm.eval()

# Float reference output, computed after eval() and without autograd.
# Previously this ran in training mode and recorded a 257-step autograd graph
# that was never used.
with torch.no_grad():
    y, (hy, cy) = lstm(x)

# 1. Prepare
lstm.qconfig = tq.default_qconfig
lstm_prepared = tq.prepare(lstm, inplace=False)

# 2. Calibrate
with torch.no_grad():
    lstm_prepared(x)

# 3. Convert
lstm_converted = tq.convert(lstm_prepared, inplace=False)

  reduce_range will be deprecated in a future release of PyTorch."


In [12]:
qy, hidden = lstm_converted(qx)
print(qy.shape)

# Sanity-check the quantized model against the float reference `y` from the
# previous cell (includes the input quantization error of `qx`).
quant_err = (qy.dequantize() - y.detach()).abs()
print(f"max abs error: {quant_err.max().item():.4f}, mean abs error: {quant_err.mean().item():.4f}")

torch.Size([257, 7, 61])


## Graph-mode quantization (TorchScript and FX)

In [13]:
class Model(nn.Module):
    """Toy conv -> LSTM model used to exercise JIT and FX quantization.

    A 1x1 convolution is applied to an NCHW input. Each spatial position then
    becomes one LSTM time step, and the channel vector at that position is
    the step's feature vector.

    Args:
        channels: conv in/out channels, also used as the LSTM input and
            hidden size (default 32, the previously hard-coded value).
    """
    def __init__(self, channels=32):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=1, padding=0)
        self.lstm = nn.LSTM(input_size=channels, hidden_size=channels, batch_first=False)

    def forward(self, x):
        x = self.conv(x)  # (N, C, H, W)
        # (N, C, H, W) -> (N, C, H*W) -> (H*W, N, C): sequence-first with the
        # channel vector as features. The previous `reshape(1, -1, 32)` read the
        # NCHW buffer in memory order, so each "feature vector" mixed values
        # from different channels/positions and folded the batch into time.
        x = x.flatten(2).permute(2, 0, 1)
        x = self.lstm(x)[0]
        return x
    
def calibrate(model, calib_data):
    """Feed every item of ``calib_data`` through ``model`` with autograd disabled.

    The outputs are discarded. This is the ``run_fn`` handed to
    ``tq.quantize_jit``, which uses it to let the prepared model's observers
    see calibration data.
    """
    with torch.no_grad():
        for batch in calib_data:
            model(batch)

In [18]:
from torch import jit

# TorchScript graph-mode quantization of the conv + nn.LSTM model.
model = jit.script(Model()).eval()
qm = tq.quantize_jit(
    model,
    {'': tq.default_qconfig},
    calibrate,
    [[torch.rand(1, 32, 1, 10)]],  # run_args: forwarded as calibrate(model, *run_args)
)

qm.inlined_graph

graph(%self : __torch__.___torch_mangle_49.Model,
      %x.2 : Tensor):
  %2 : float = prim::Constant[value=0.]() # /home/zafar/Git/pytorch-dev/pytorch-lstm/torch/nn/modules/rnn.py:582:30
  %3 : bool = prim::Constant[value=1]() # /home/zafar/Git/pytorch-dev/pytorch-lstm/torch/nn/modules/rnn.py:581:61
  %4 : int = prim::Constant[value=32]() # /home/zafar/Git/pytorch-dev/pytorch-lstm/torch/nn/modules/rnn.py:571:48
  %batch_sizes.1 : None = prim::Constant() # /home/zafar/Git/pytorch-dev/pytorch-lstm/torch/nn/modules/rnn.py:563:26
  %6 : str = prim::Constant[value="input.size(-1) must be equal to input_size. Expected {}, got {}"]() # /home/zafar/Git/pytorch-dev/pytorch-lstm/torch/nn/modules/rnn.py:179:16
  %7 : int = prim::Constant[value=-1]() # /home/zafar/Git/pytorch-dev/pytorch-lstm/torch/nn/modules/rnn.py:177:41
  %8 : str = prim::Constant[value="input must have {} dimensions, got {}"]() # /home/zafar/Git/pytorch-dev/pytorch-lstm/torch/nn/modules/rnn.py:175:16
  %9 : int = prim::Consta

In [28]:
from torch import fx

model = fx.symbolic_trace(Model())
model.eval()

# 1. Prepare
model_prepared = tq.prepare_fx(model, {'': tq.default_qconfig})

# 2. Calibrate on several distinct batches: repeating one tensor gives the
# observers exactly the same statistics as a single pass.
# NOTE: this rebinds the notebook-global `x` used by the eager-mode cells above.
x = torch.randn(1, 32, 1, 10)
with torch.no_grad():
    model_prepared(x)
    for _ in range(4):
        model_prepared(torch.randn_like(x))

# 3. Convert
model_converted = tq.convert_fx(model_prepared)
model_converted

GraphModuleImpl(
  (conv): QuantizedConv2d(32, 32, kernel_size=(1, 1), stride=(1, 1), scale=0.027948811650276184, zero_point=63)
  (lstm): LSTM(32, 32)
)