In [1]:
import torch 
import os
from torch.quantization import fuse_modules
import numpy as np

In [2]:
from model import SegNet

device = torch.device('cuda:{}'.format(0) if torch.cuda.is_available() else 'cpu')
model = SegNet(3, 1).to(device)

def load_model(weight_path):
    print('Loading model weights from {}'.format(weight_path))
    last_weight = os.path.join(weight_path, "best.pt")
    chkpt = torch.load(last_weight, map_location = device)
    model.load_state_dict(chkpt['model'])
    del chkpt



In [3]:
load_model('./weights')

Loading model weights from ./weights


In [4]:
def get_model_size(model):
    param_size = 0
    for param in model.parameters():
        param_size += param.nelement() * param.element_size()
    buffer_size = 0
    for buffer in model.buffers():
        buffer_size += buffer.nelement() * buffer.element_size()
    
    size_all_mb = (param_size + buffer_size) / 1024**2
    print('model size: {:.3f}MB'.format(size_all_mb))


In [5]:
get_model_size(model)

model size: 640.171MB


In [6]:
quant_model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear})

In [7]:
get_model_size(quant_model)

model size: 168.511MB


In [8]:
from get_data import get_data_loaders
import torch.nn as nn
import utils
_, _ , test_loader = get_data_loaders("./data", 416, 416, device) 


def get_model_perform(model):
    iou_thresh = [0.5, 0.95, 0.05]
    evaluator = utils.Metrics(model, iou_thresh, device)
    criterion = nn.BCEWithLogitsLoss()
    with torch.no_grad():
        val_loss, acc, dice, iou, mAP = evaluator(test_loader, criterion)
        s = ('test loss : %.3f, Accuracy : %.3f  Dice : %.3f  IoU : %.3f mAP@[%.2f | %.2f | %.2f] : %.3f') % (val_loss, acc.item(), dice.item(), iou.item(), iou_thresh[0],
                                                                                            iou_thresh[2], iou_thresh[1], mAP.item())
        print(s)



  check_for_updates()


In [None]:
get_model_perform(model)

In [None]:
get_model_perform(quant_model)

In [9]:
from model import ConvBlock
def fuse_layers(module):
    for idx, block in enumerate(module.layers):
        if isinstance(block, ConvBlock):
            fuse_list = ['conv', 'bn', 'relu']  # These layers will be fused
            fuse_modules(block, fuse_list, inplace=True)


In [10]:
# Static quantization of a model consists of the following steps:

#     Fuse modules
#     Insert Quant/DeQuant Stubs
#     Prepare the fused module (insert observers before and after layers)
#     Calibrate the prepared module (pass it representative data)
#     Convert the calibrated module (replace with quantized version)

import torch
from torch import nn
import copy

backend = "fbgemm"  # running on a x86 CPU. Use "qnnpack" if running on ARM.

m = copy.deepcopy(model)
m.eval()

"""Fuse
- Inplace fusion replaces the first module in the sequence with the fused module, and the rest with identity modules
"""
# Apply fusion to all encoder and decoder layers
for encoder in m.encoder:
    fuse_layers(encoder)

for decoder in m.decoder:
    fuse_layers(decoder)
# torch.quantization.fuse_modules(m, ['conv','bn', 'relu'], inplace=True) # fuse first Conv-ReLU pair


# """Insert stubs"""
# m = nn.Sequential(torch.quantization.QuantStub(), 
#                   *m, 
#                   torch.quantization.DeQuantStub())

# """Prepare"""
# m.qconfig = torch.quantization.get_default_qconfig(backend)
# torch.quantization.prepare(m, inplace=True)

# """Calibrate
# - This example uses random data for convenience. Use representative (validation) data instead.
# """
# with torch.inference_mode():
#   for _ in range(10):
#     x = torch.rand(1,2, 28, 28)
#     m(x)
    
# """Convert"""
# torch.quantization.convert(m, inplace=True)

# """Check"""
# print(m[[1]].weight().element_size()) # 1 byte instead of 4 bytes for FP32


# ## FX GRAPH
# from torch.quantization import quantize_fx
# m = copy.deepcopy(model)
# m.eval()
# qconfig_dict = {"": torch.quantization.get_default_qconfig(backend)}
# # Prepare
# model_prepared = quantize_fx.prepare_fx(m, qconfig_dict)
# # Calibrate - Use representative (validation) data.
# with torch.inference_mode():
#   for _ in range(10):
#     x = torch.rand(1,2,28, 28)
#     model_prepared(x)
# # quantize
# model_quantized = quantize_fx.convert_fx(model_prepared)

In [13]:
m.decoder

ModuleList(
  (0): Decoder(
    (layers): ModuleList(
      (0-2): 3 x ConvTBlock(
        (convT): ConvTranspose2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU()
      )
    )
  )
  (1): Decoder(
    (layers): ModuleList(
      (0-1): 2 x ConvTBlock(
        (convT): ConvTranspose2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU()
      )
      (2): ConvTBlock(
        (convT): ConvTranspose2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU()
      )
    )
  )
  (2): Decoder(
    (layers): ModuleList(
      (0-1): 2 x ConvTBlock(
        (convT): ConvTranspose2d(256, 256, kernel_size=(3, 3), stride=(1