In [10]:
# Import modules
import os
import cv2
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

In [11]:
# Notebook-wide configuration: number of output classes and checkpoint location
NUM_CLASSES = 2
model_path = f"{os.getcwd()}/model/best_model.pth"

# Compute device: "cpu" or "cuda:0"
device = torch.device("cpu")
print(device)

cpu


In [12]:
class Net(nn.Module):
    """Small three-convolution CNN classifier wrapped in quant/dequant stubs.

    The ``QuantStub``/``DeQuantStub`` pair marks where tensors enter and leave
    the quantized region when the model is converted with eager-mode static
    quantization; before conversion they pass tensors through unchanged.

    Args:
        num_classes: Number of output units of the final linear layer. When
            omitted, the notebook-level ``NUM_CLASSES`` constant is used, so
            ``Net()`` keeps working exactly as before.

    The first fully connected layer is sized from a probe of shape
    ``(1, 3, 224, 224)``, so ``forward`` expects batches of 3x224x224 inputs.
    """

    def __init__(self, num_classes=None):
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        # Resolve the class count once so the layer size and the output
        # activation chosen in forward() can never disagree.
        self.num_classes = NUM_CLASSES if num_classes is None else num_classes
        self.conv1 = nn.Conv2d(3, 32, 5)
        self.conv2 = nn.Conv2d(32, 64, 3)
        self.conv3 = nn.Conv2d(64, 128, 3)

        # Run one dummy batch through the conv stack to discover the flattened
        # feature size. No gradients are needed for a shape probe.
        self._to_linear = None
        with torch.no_grad():
            self.convs(torch.randn(1, 3, 224, 224))

        self.fc1 = nn.Linear(self._to_linear, 10)
        self.fc2 = nn.Linear(10, self.num_classes)
        self.dequant = torch.quantization.DeQuantStub()

    def convs(self, x):
        """Apply the three conv -> ReLU -> 2x2 max-pool stages.

        On the first call (while ``_to_linear`` is still ``None``) the number
        of features per sample is recorded for sizing ``fc1``.
        """
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv3(x)), (2, 2))

        if self._to_linear is None:
            # Elements in one sample: channels * height * width
            self._to_linear = x[0].numel()
        return x

    def forward(self, x):
        """Return per-class scores for a batch of images.

        Returns:
            Element-wise sigmoid of the final layer when ``num_classes == 2``,
            otherwise a softmax over dim 1.
        """
        x = self.quant(x)
        x = self.convs(x)
        x = x.reshape(-1, self._to_linear)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        x = self.dequant(x)
        if self.num_classes == 2:
            # torch.sigmoid is numerically identical to F.sigmoid but does not
            # trigger the deprecation warning some PyTorch releases emit.
            return torch.sigmoid(x)
        return F.softmax(x, dim=1)

In [13]:
# Rebuild the float architecture and restore the trained weights.
# map_location remaps the stored tensors onto `device`, so a checkpoint that
# was saved from a GPU session still loads on a CPU-only machine.
# NOTE: torch.load unpickles the file; only load checkpoints you trust.
model = Net()
state_dict = torch.load(model_path, map_location=device)
model.load_state_dict(state_dict)
model.to(device)
model.eval()

Net(
  (quant): QuantStub()
  (conv1): Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
  (conv3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
  (fc1): Linear(in_features=86528, out_features=10, bias=True)
  (fc2): Linear(in_features=10, out_features=2, bias=True)
  (dequant): DeQuantStub()
)

In [14]:
# Size of the original (float) checkpoint on disk, in megabytes
original_size_mb = os.path.getsize(model_path) / 1e6
print("%.2f MB" % original_size_mb)

3.84 MB


In [15]:
# Post-training static quantization (eager mode)
backend = "fbgemm"  # "fbgemm" or "qnnpack"
torch.backends.quantized.engine = backend
model.qconfig = torch.quantization.get_default_qconfig(backend)

# Insert observers into a copy of the float model (inplace=False leaves `model` untouched)
model_observed = torch.quantization.prepare(model, inplace=False)

# Calibration pass: the observers record activation ranges during this forward.
# No gradients are needed, so skip building the autograd graph.
# NOTE(review): a single random-noise tensor is not representative of real
# images; calibrate with a few batches of real data for usable scales.
input_ = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    model_observed(input_)

# Replace observed modules with their quantized counterparts
model_static_quantized = torch.quantization.convert(model_observed, inplace=False)

In [8]:
def print_model_size(mdl):
    """Save ``mdl``'s state dict to ``model/quant_model.pth`` and print its size.

    The file is written under the current working directory and is left on
    disk after the call. The size is printed in megabytes (bytes / 1e6) with
    two decimals.
    """
    checkpoint_path = os.path.join(os.getcwd(), "model", "quant_model.pth")
    torch.save(mdl.state_dict(), checkpoint_path)
    size_mb = os.path.getsize(checkpoint_path) / 1e6
    print("%.2f MB" % size_mb)

In [9]:
# Write the quantized state dict to model/quant_model.pth and report its on-disk size
print_model_size(model_static_quantized) 

0.97 MB


In [10]:
# Show the converted module tree with each quantized layer's scale and zero point
print(model_static_quantized)

Net(
  (quant): Quantize(scale=tensor([0.0315]), zero_point=tensor([122]), dtype=torch.quint8)
  (conv1): QuantizedConv2d(3, 32, kernel_size=(5, 5), stride=(1, 1), scale=0.02010815218091011, zero_point=142)
  (conv2): QuantizedConv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), scale=0.016971413046121597, zero_point=109)
  (conv3): QuantizedConv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), scale=0.04692097380757332, zero_point=67)
  (fc1): QuantizedLinear(in_features=86528, out_features=10, scale=5.125334739685059, zero_point=25, qscheme=torch.per_tensor_affine)
  (fc2): QuantizedLinear(in_features=10, out_features=2, scale=3.220567226409912, zero_point=144, qscheme=torch.per_tensor_affine)
  (dequant): DeQuantize()
)


In [11]:
# Test loading: a quantized state dict can only be loaded into a model that
# already has the quantized module structure, so rebuild that structure first
# (prepare -> calibrate -> convert) and then overwrite it with the saved values.
model = Net()
model.eval()  # static quantization expects an eval-mode model
backend = "fbgemm"
torch.backends.quantized.engine = backend  # keep this cell runnable on its own
model.qconfig = torch.quantization.get_default_qconfig(backend)
model_prepared = torch.quantization.prepare(model, inplace=False)

# The calibration values are throwaway here: scales and zero points are
# replaced by the ones stored in the checkpoint.
input_ = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    model_prepared(input_)
quant_model = torch.quantization.convert(model_prepared)

# Load the file written by print_model_size() so the notebook is reproducible
# from a fresh kernel. The default strict=True makes any missing or unexpected
# key fail loudly instead of silently leaving layers with calibration values.
state_dict = torch.load(os.path.join(os.getcwd(), "model", "quant_model.pth"))
quant_model.load_state_dict(state_dict)
quant_model.eval()

Net(
  (quant): Quantize(scale=tensor([0.0315]), zero_point=tensor([122]), dtype=torch.quint8)
  (conv1): QuantizedConv2d(3, 32, kernel_size=(5, 5), stride=(1, 1), scale=0.02010815218091011, zero_point=142)
  (conv2): QuantizedConv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), scale=0.016971413046121597, zero_point=109)
  (conv3): QuantizedConv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), scale=0.04692097380757332, zero_point=67)
  (fc1): QuantizedLinear(in_features=86528, out_features=10, scale=5.125334739685059, zero_point=25, qscheme=torch.per_tensor_affine)
  (fc2): QuantizedLinear(in_features=10, out_features=2, scale=3.220567226409912, zero_point=144, qscheme=torch.per_tensor_affine)
  (dequant): DeQuantize()
)

In [16]:
# Alternative persistence route: compile the quantized model to TorchScript
# and save the scripted module, which stores the structure along with the weights.
scripted_quantized = torch.jit.script(model_static_quantized)
torch.jit.save(scripted_quantized, "./model/quantized_test.pt")

In [17]:
# Reload the TorchScript archive from ./model/quantized_test.pt; no Net() instance is constructed here
quantized = torch.jit.load("./model/quantized_test.pt")

In [18]:
# Inspect the module tree of the reloaded TorchScript model
print(quantized)

RecursiveScriptModule(
  original_name=Net
  (quant): RecursiveScriptModule(original_name=Quantize)
  (conv1): RecursiveScriptModule(original_name=Conv2d)
  (conv2): RecursiveScriptModule(original_name=Conv2d)
  (conv3): RecursiveScriptModule(original_name=Conv2d)
  (fc1): RecursiveScriptModule(
    original_name=Linear
    (_packed_params): RecursiveScriptModule(original_name=LinearPackedParams)
  )
  (fc2): RecursiveScriptModule(
    original_name=Linear
    (_packed_params): RecursiveScriptModule(original_name=LinearPackedParams)
  )
  (dequant): RecursiveScriptModule(original_name=DeQuantize)
)
