In [36]:
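# Long short-term memory (LSTM): a special kind of recurrent neural network.
# It carries a hidden state and a cell state from step to step, which brings
# persistence to inference: each new step does not start from scratch.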

In [37]:
# Quantization lets one make a trade-off between performance and accuracy for a
# given model after its training is complete.

# Quantization -> instantiate a floating-point model and create a quantized version of it

In [2]:
# import the modules used in this recipe

import copy
import os
import tempfile
import time

import torch
import torch.nn as nn
import torch.quantization

In [3]:
class lstm_for_demo(nn.Module):
    """Minimal demo module: a thin wrapper around a single ``nn.LSTM``.

    Args:
        in_dim: number of features per input step (``input_size`` of the LSTM).
        out_dim: size of the hidden and cell state (``hidden_size`` of the LSTM).
        depth: number of stacked LSTM layers (``num_layers`` of the LSTM).
    """

    def __init__(self, in_dim, out_dim, depth):
        super().__init__()
        self.lstm = nn.LSTM(input_size=in_dim, hidden_size=out_dim, num_layers=depth)

    def forward(self, inputs, hidden):
        """Run the wrapped LSTM over ``inputs``.

        ``hidden`` is the initial ``(h_0, c_0)`` state tuple. The LSTM's own
        ``(output, (h_n, c_n))`` result is handed back unchanged.
        """
        return self.lstm(inputs, hidden)

# Seed the global RNG so the random inputs, initial state and the model weights
# created afterwards are reproducible.
torch.manual_seed(29592)

# Shape parameters
model_dimension = 800
sequence_length = 20
batch_size = 1
lstm_depth = 1

# Random input sequence of shape (sequence_length, batch_size, model_dimension).
inputs = torch.randn(sequence_length, batch_size, model_dimension)

# Initial recurrent state: a tuple of the initial hidden state and the initial
# cell state, each of shape (lstm_depth, batch_size, model_dimension).
# They are drawn in that order (hidden first, then cell) from the seeded RNG.
hidden = tuple(
    torch.randn(lstm_depth, batch_size, model_dimension) for _ in range(2)
)

Quantization

In [4]:
# Build the fp32 reference model; input size and hidden size are both model_dimension.
float_lstm = lstm_for_demo(model_dimension, model_dimension, lstm_depth)

# Dynamic quantization returns a new module in which submodules of the listed
# types (nn.LSTM / nn.Linear) are replaced by dynamically quantized int8
# versions. float_lstm itself is left as-is because inplace defaults to False.
quantized_lstm = torch.quantization.quantize_dynamic(
    model=float_lstm,
    qconfig_spec={nn.LSTM, nn.Linear},
    dtype=torch.qint8,
)

print(f"floating-point version of the module {float_lstm}")
print(f'quantized version {quantized_lstm}')

floating-point version of the module lstm_for_demo(
  (lstm): LSTM(800, 800)
)
quantized version lstm_for_demo(
  (lstm): DynamicQuantizedLSTM(800, 800)
)


Model Size

In [5]:
def print_size_of_model(model, label=" "):
    """Print and return the serialized size of ``model``'s state_dict.

    The state_dict is written with ``torch.save`` into a fresh temporary
    directory, the resulting file size is measured, and the directory is
    removed again.

    Args:
        model: module whose ``state_dict()`` is serialized.
        label: tag printed next to the size (e.g. "fp32" / "int8").

    Returns:
        int: size of the serialized file in bytes. The printed figure is
        this value divided by 1e3.
    """
    # A private temporary directory (instead of a fixed "temp.p" in the CWD)
    # means an existing user file called temp.p is never overwritten or
    # deleted, and cleanup still happens if torch.save raises.
    # The file keeps the name "temp.p": torch.save's zip format records an
    # archive name derived from the file name, so this keeps the measured
    # size identical to the previous implementation.
    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, "temp.p")
        torch.save(model.state_dict(), path)
        size = os.path.getsize(path)
    print(f"model: {label} \t size(kb): {size/1e3}")
    return size

# Compare the serialized state_dict sizes of the two models (values in bytes).
f = print_size_of_model(float_lstm, "fp32")   # fp32 baseline
q = print_size_of_model(quantized_lstm, "int8")  # dynamically quantized copy
print(
    "{0:.2f} times smaller".format(f / q)
)

model: fp32 	 size(kb): 20507.039
model: int8 	 size(kb): 5147.551
3.98 times smaller


Accuracy

In [6]:
# Accuracy check: run both models on the same inputs and initial state.
# This is pure inference, so torch.no_grad() stops the fp32 model from
# recording an autograd graph it will never use.
with torch.no_grad():
    # run the float model
    out1, hidden1 = float_lstm(inputs, hidden)
    # run the quantized model
    out2, hidden2 = quantized_lstm(inputs, hidden)

mag1 = out1.abs().mean().item()
print("Mean abs value of output tensor values in fp32 model is {0:.5f}".format(mag1))

mag2 = out2.abs().mean().item()
print("Mean abs value of output tensor values in int8 model is {0:.5f}".format(mag2))

# Similar mean magnitudes alone do not show that the two outputs agree
# element by element, so also report the largest element-wise deviation.
max_abs_diff = (out1 - out2).abs().max().item()
print("Max abs difference between fp32 and int8 outputs is {0:.5f}".format(max_abs_diff))

Mean abs value of output tensor values in fp32 model is 0.11407
Mean abs value of output tensor values in int8 model is 0.11412


Latency: quantized models tend to run faster, because int8 operations are cheaper and less time is spent moving parameter data.

In [7]:
# compare the performance
print("Floating Point")
%timeit float_lstm.forward(inputs,hidden)

print("Quant INT8")
%timeit quantized_lstm.forward(inputs,hidden)


Floating Point
18 ms ± 1.47 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
Quant INT8
2.21 ms ± 231 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
