In [1]:
import sys
sys.path.append("..")
import torch
import torch.nn as nn
from models.networks.MLP2 import MLP2
from models.networks.MLP3 import MLP3
from models.networks.ConvertMLP import ConvertMLP2,ConvertMLP3
import numpy as np

The Zen of Python, by Tim Peters

Beautiful is better than ugly.
Explicit is better than implicit.
Simple is better than complex.
Complex is better than complicated.
Flat is better than nested.
Sparse is better than dense.
Readability counts.
Special cases aren't special enough to break the rules.
Although practicality beats purity.
Errors should never pass silently.
Unless explicitly silenced.
In the face of ambiguity, refuse the temptation to guess.
There should be one-- and preferably only one --obvious way to do it.
Although that way may not be obvious at first unless you're Dutch.
Now is better than never.
Although never is often better than *right* now.
If the implementation is hard to explain, it's a bad idea.
If the implementation is easy to explain, it may be a good idea.
Namespaces are one honking great idea -- let's do more of those!


## start from a trained model

In [3]:

model = MLP3(input_dim=(28,28),hidden_dim=64,output_dim=10).cpu().eval()
# get the model parameters
state = model.state_dict()

In [4]:
# convet the model to a quantizable model
c_model= ConvertMLP2(model)
# load previous parameter
c_model.load_state_dict(state)

<All keys matched successfully>

In [5]:
# set quantization config, can also set maunually
# see https://github.com/kredde/compression-robustness/blob/e71b2ac493577e8e4b6c7a414dd48850fdeefd63/src/experiments/static_quantization.py#L39
c_model.qconfig = torch.quantization.get_default_qconfig('fbgemm')

In [6]:
# fuse model
## to do compare fuse or not
c_model.fuse()



In [7]:
# use calibrate date to calibrate model
model_prepared = torch.quantization.prepare(c_model)
# _input should be some calibration data
_input = torch.randn(3,28,28)
model_prepared(_input)

  reduce_range will be deprecated in a future release of PyTorch."


tensor([[ 0.0389,  0.0415, -0.6710,  0.3382, -0.0209, -0.2361,  0.0457,  0.8677,
          0.5616, -1.1181],
        [ 0.0525,  0.4014, -0.6845,  0.6979, -0.0962,  0.1083, -0.5209,  0.0190,
         -0.1059,  0.7508],
        [ 0.9061,  0.1395,  0.0837,  0.3281, -0.0590,  0.5052, -1.0965,  0.5720,
          0.5648,  0.3750]], grad_fn=<AddmmBackward0>)

In [8]:
# convert to int8
model_int8 = torch.quantization.convert(model_prepared)

  src_bin_begin // dst_bin_width, 0, self.dst_nbins - 1
  src_bin_end // dst_bin_width, 0, self.dst_nbins - 1


In [9]:
res = model_int8(_input)
model_int8

ConvertMLP2(
  (layers): Sequential(
    (0): QuantizedLinearReLU(in_features=784, out_features=64, scale=0.020231977105140686, zero_point=0, qscheme=torch.per_channel_affine)
    (1): Identity()
    (2): QuantizedLinear(in_features=64, out_features=64, scale=0.02820894680917263, zero_point=71, qscheme=torch.per_channel_affine)
    (3): ReLU()
    (4): QuantizedLinear(in_features=64, out_features=10, scale=0.010910957120358944, zero_point=44, qscheme=torch.per_channel_affine)
  )
  (quant): Quantize(scale=tensor([0.0524]), zero_point=tensor([62]), dtype=torch.quint8)
  (dequant): DeQuantize()
)