In [1]:
# Import modules
import os
import cv2
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# Location of the trained float checkpoint and the number of output classes.
# os.path.join (rather than string concatenation) keeps the path separator
# correct on every platform.
model_path = os.path.join(os.getcwd(), "model", "best_model.pth")
NUM_CLASSES = 2

In [3]:
class Net(nn.Module):
    """Small CNN: three conv/ReLU/max-pool stages followed by two linear layers.

    The input passes through ``QuantStub`` and the output through
    ``DeQuantStub``; dynamic quantization in this notebook leaves both stubs
    unchanged.

    Args:
        num_classes: number of outputs of ``fc2``. When ``None`` (default) the
            notebook-level ``NUM_CLASSES`` is used. The value is captured at
            construction time so ``forward`` always agrees with ``fc2``.
        input_size: height/width of the square 3-channel input the network is
            built for. It determines ``fc1.in_features`` (224 -> 86528, which
            is what the saved checkpoint expects), so images passed to
            ``forward`` must have this size.
    """

    def __init__(self, num_classes=None, input_size=224):
        super().__init__()
        # Only fall back to the global when no explicit value was given.
        self.num_classes = NUM_CLASSES if num_classes is None else num_classes
        self.input_size = input_size

        self.quant = torch.quantization.QuantStub()
        self.conv1 = nn.Conv2d(3, 32, 5)
        self.conv2 = nn.Conv2d(32, 64, 3)
        self.conv3 = nn.Conv2d(64, 128, 3)

        # Determine the flattened conv-feature size with one dummy pass.
        # zeros (not randn) so constructing the model does not advance the
        # global RNG; no_grad because only the output shape is needed.
        self._to_linear = None
        with torch.no_grad():
            self.convs(torch.zeros(1, 3, input_size, input_size))

        self.fc1 = nn.Linear(self._to_linear, 10)
        self.fc2 = nn.Linear(10, self.num_classes)
        self.dequant = torch.quantization.DeQuantStub()

    def convs(self, x):
        """Apply conv -> ReLU -> 2x2 max-pool three times.

        On the first call (the probe in ``__init__``) it also stores the number
        of features per sample in ``self._to_linear``.
        """
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv3(x)), (2, 2))

        if self._to_linear is None:
            self._to_linear = x[0].numel()
        return x

    def forward(self, x):
        """Return per-class scores of shape ``(N, num_classes)``.

        Args:
            x: image batch, presumably ``(N, 3, input_size, input_size)``.

        Returns:
            Element-wise sigmoid of the logits when ``num_classes == 2``,
            otherwise softmax over dim 1.
        """
        x = self.quant(x)
        x = self.convs(x)
        # flatten(x, 1) keeps the batch dimension intact; the former
        # reshape(-1, _to_linear) could silently move features of a
        # wrongly-sized input into the batch dimension.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        x = self.dequant(x)
        if self.num_classes == 2:
            # NOTE(review): sigmoid over two outputs yields independent scores
            # that need not sum to 1; kept as-is because the checkpoint was
            # presumably trained this way — confirm against the training code.
            return torch.sigmoid(x)
        return F.softmax(x, dim=1)

In [4]:
# Rebuild the float model and load the trained weights.
model = Net()
# map_location="cpu": dynamic quantization is a CPU-only feature, and this also
# lets a checkpoint saved from a GPU session load on a CPU-only machine.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints
# (consider weights_only=True on torch >= 1.13).
state_dict = torch.load(model_path, map_location="cpu")
model.load_state_dict(state_dict)
model.eval()

Net(
  (quant): QuantStub()
  (conv1): Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
  (conv3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
  (fc1): Linear(in_features=86528, out_features=10, bias=True)
  (fc2): Linear(in_features=10, out_features=2, bias=True)
  (dequant): DeQuantStub()
)

In [5]:
# On-disk size of the float checkpoint, in MB (1 MB = 1e6 bytes)
print(f"{os.path.getsize(model_path) / 1e6:.2f} MB")

3.84 MB


In [12]:
# Dynamic quantization only has quantized replacements for a few module types
# (e.g. nn.Linear, nn.LSTM). nn.Conv2d and the function F.max_pool2d, which were
# previously listed here, were silently ignored — the printed model still shows
# float Conv2d layers. Quantizing the convs would need static quantization
# (prepare -> calibrate -> convert), presumably what the Quant/DeQuant stubs in
# Net are for.
model_dynamic_quantized = torch.quantization.quantize_dynamic(
    model, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8
)

In [13]:
def print_model_size(mdl, path=None):
    """Save ``mdl``'s state_dict to ``path`` and print the file size in MB.

    Args:
        mdl: module whose ``state_dict()`` is serialized with ``torch.save``.
        path: destination file. Defaults to
            ``<cwd>/model/quant_model_dynamic.pth``; an existing file there is
            overwritten. Missing parent directories are created.

    Prints the size as ``"<x.xx> MB"`` using 1 MB = 1e6 bytes.
    """
    if path is None:
        path = os.path.join(os.getcwd(), "model", "quant_model_dynamic.pth")
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    torch.save(mdl.state_dict(), path)
    print("%.2f MB" % (os.path.getsize(path) / 1e6))

In [14]:
# Serialize the dynamically quantized model's state_dict and print its size,
# for comparison with the original checkpoint size printed earlier.
print_model_size(model_dynamic_quantized) 

1.25 MB


In [15]:
# Inspect which layers were converted: only fc1/fc2 become
# DynamicQuantizedLinear; the conv layers and Quant/DeQuant stubs stay as-is.
print(model_dynamic_quantized)

Net(
  (quant): QuantStub()
  (conv1): Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
  (conv3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
  (fc1): DynamicQuantizedLinear(in_features=86528, out_features=10, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
  (fc2): DynamicQuantizedLinear(in_features=10, out_features=2, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
  (dequant): DeQuantStub()
)
