In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

In [3]:
import torch
from torch import nn
import torch.nn.functional as F

import torch.quantization as tq
from torch.nn import quantized as nnq

In [33]:
def quantize(x, dtype):
    iinfo = torch.iinfo(dtype)
    qmin = iinfo.min
    qmax = iinfo.max
    fmin = min(0, x.min().item())
    fmax = max(0, x.max().item())
    
    if fmin == fmax:
        fdiff = 1.0
    else:
        fdiff = (fmax - fmin)
    
    scale = fdiff / (qmax - qmin)
    zp = int(round(qmin - fmin / scale))
    qx = torch.quantize_per_tensor(x, scale=scale, zero_point=zp, dtype=dtype)
    return qx

def SNR(x, qx):
    if qx.is_quantized:
        qx = qx.dequantize()
    noise = ((x - qx)**2).mean().item()
    signal = x.square().mean().item()
    snr = signal / noise
    snr_db = 10 * np.log10(snr)
    return snr, snr_db    

In [34]:
from pose_mobilenet import get_pose_net

model = get_pose_net(is_train=False)
model.eval()
model.fuse_model()
model.qconfig = tq.get_default_qconfig('qnnpack')

In [35]:
# from torchsummary import summary
# summary(model, (3, 192, 256), device='cpu')
model

PoseMobileNet(
  (features): Sequential(
    (0): ConvBNReLU(
      (0): ConvReLU2d(
        (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (1): ReLU(inplace=True)
      )
      (1): Identity()
      (2): Identity()
    )
    (1): InvertedResidual(
      (conv): Sequential(
        (0): ConvBNReLU(
          (0): ConvReLU2d(
            (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)
            (1): ReLU(inplace=True)
          )
          (1): Identity()
          (2): Identity()
        )
        (1): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1))
        (2): Identity()
      )
      (skip_add): FloatFunctional(
        (activation_post_process): Identity()
      )
    )
    (2): InvertedResidual(
      (conv): Sequential(
        (0): ConvBNReLU(
          (0): ConvReLU2d(
            (0): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1))
            (1): ReLU(inplace=True)
          )
          (1): Identity()
  

In [38]:
N = 16
C = 3
H, W = 192, 256
x = torch.randn(N, C, H, W)
qx = quantize(x, torch.quint8)

SNR(x, qx)

(7934.060356838403, 38.994954999624525)

In [39]:
import copy

# Create quantized model
model_prepared = tq.prepare(model, inplace=False)
model_calibrated = copy.deepcopy(model_prepared)
model_calibrated(x);
model_converted = tq.convert(model_calibrated, inplace=False)

In [40]:
model_converted

PoseMobileNet(
  (features): Sequential(
    (0): ConvBNReLU(
      (0): QuantizedConvReLU2d(3, 32, kernel_size=(3, 3), stride=(2, 2), scale=0.009284310974180698, zero_point=0, padding=(1, 1))
      (1): Identity()
      (2): Identity()
    )
    (1): InvertedResidual(
      (conv): Sequential(
        (0): ConvBNReLU(
          (0): QuantizedConvReLU2d(32, 32, kernel_size=(3, 3), stride=(1, 1), scale=0.004920272622257471, zero_point=0, padding=(1, 1), groups=32)
          (1): Identity()
          (2): Identity()
        )
        (1): QuantizedConv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), scale=0.0030450373888015747, zero_point=113)
        (2): Identity()
      )
      (skip_add): QFunctional(
        scale=1.0, zero_point=0
        (activation_post_process): Identity()
      )
    )
    (2): InvertedResidual(
      (conv): Sequential(
        (0): ConvBNReLU(
          (0): QuantizedConvReLU2d(16, 96, kernel_size=(1, 1), stride=(1, 1), scale=0.001046488992869854, zero_point=0)


In [44]:
y = model(x)
qy = model_converted(x)

print("SNR: {}, SNR(dB): {}".format(*SNR(y, qy)))

SNR: 72752.96636753932, SNR(dB): 48.6185070557258
