# Imports

In [1]:
import random

import numpy as np

import torch
import torch.onnx
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

import onnx
import onnxoptimizer
import onnxruntime as ort
from onnxruntime.quantization import CalibrationDataReader, quantize_static, QuantType, QuantFormat
from onnxsim import simplify

from utils.model import ConvNet

# Config

In [2]:
CFG = dict(
    seed=42,                         # global RNG seed (see seed_everything)
    bs=1,                            # calibration DataLoader batch size
    imgsz=(64, 64),                  # (height, width) images are resized to
    mean=(0.485, 0.456, 0.406),      # per-channel normalization mean
    std=(0.229, 0.224, 0.225),       # per-channel normalization std
    n_cls=3,                         # number of classes the ConvNet predicts
)

# Fix Random Seeds

In [3]:
def seed_everything(seed):
    """Seed Python's, NumPy's and PyTorch's RNGs and force deterministic cuDNN.

    Args:
        seed: integer seed handed to every generator.
    """
    # Python and NumPy global generators.
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch CPU generator, then the current CUDA device and all CUDA devices
    # (the CUDA seeding calls are silently ignored when no GPU is available).
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Give up cuDNN autotuning in exchange for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Seed every RNG (Python, NumPy, PyTorch) before any random tensors are drawn in later cells.
seed_everything(CFG['seed'])

# Load Model

In [4]:
# Rebuild the architecture, then restore the fine-tuned weights mapped onto the CPU.
# weights_only=True restricts unpickling to tensors and plain containers (safer loading).
checkpoint_path = "./checkpoints/fine_tuned_conv_net.pth"
model = ConvNet(num_classes=CFG['n_cls'])
model.load_state_dict(
    torch.load(checkpoint_path, map_location="cpu", weights_only=True)
)

<All keys matched successfully>

# ONNX Export

In [5]:
img_h, img_w = CFG['imgsz']
# Example input used to trace the graph; its batch dim of 1 is marked dynamic below.
dummy_input = torch.randn(1, 3, img_h, img_w)

# Put dropout / batch-norm layers in inference behaviour explicitly before exporting, so the
# exported graph does not depend on whichever mode an earlier cell left the model in.
model.eval()

torch.onnx.export(model,
                  dummy_input,
                  "./checkpoints/conv_net.onnx",
                  input_names=['input'],
                  output_names=['output'],
                  # Dim 0 (batch) is dynamic on BOTH tensors. Declaring it only on the input
                  # leaves the output statically typed with batch 1 while any input batch is
                  # accepted. Assumes ConvNet returns (batch, n_cls) logits — TODO confirm.
                  dynamic_axes={'input': {0: 'batch_size'},
                                'output': {0: 'batch_size'}})

# INT8 Quantization

In [8]:
class ImageFolderDataReader(CalibrationDataReader):
    """Feed calibration batches from a PyTorch DataLoader to ONNX Runtime's quantizer.

    Each ``get_next`` call pulls one ``(inputs, labels)`` batch from the loader, drops the
    labels, and returns ``{input_name: inputs.numpy()}`` as the feed dict for the model.
    ``None`` is returned once the loader is exhausted, which ends calibration.

    Args:
        dataloader: iterable yielding ``(inputs, labels)`` pairs; ``inputs`` must be CPU
            tensors because ``.numpy()`` is called on them.
        input_name: name of the ONNX graph input to feed. Must match the ``input_names``
            used when the model was exported; defaults to ``"input"``.
    """

    def __init__(self, dataloader, input_name="input"):
        self.dataloader = dataloader
        self.input_name = input_name
        self.rewind()

    def rewind(self):
        """Restart from the first batch so the same reader can drive another calibration pass."""
        self.iterator = iter(self.dataloader)

    def get_next(self):
        """Return the next feed dict, or ``None`` when the dataloader is exhausted."""
        try:
            inputs, _ = next(self.iterator)
        except StopIteration:
            return None
        # Kept outside the try so an unrelated StopIteration cannot be mistaken for end-of-data.
        return {self.input_name: inputs.numpy()}

In [9]:
# Preprocessing for every calibration image: resize to CFG['imgsz'] (height, width), convert
# to a float tensor, then normalize with CFG['mean'] / CFG['std'].
# NOTE(review): this should mirror the preprocessing used at training/inference time — confirm.
calibration_steps = [
    transforms.Resize(CFG['imgsz']),
    transforms.ToTensor(),
    transforms.Normalize(mean=CFG['mean'], std=CFG['std']),
]
transform = transforms.Compose(calibration_steps)

# Calibrate on the full ./data/train/ ImageFolder dataset, visited in a fixed (unshuffled) order.
dataset = datasets.ImageFolder("./data/train/", transform=transform)
calibration_loader = DataLoader(dataset, shuffle=False, batch_size=CFG['bs'])
data_reader = ImageFolderDataReader(calibration_loader)

fp32_onnx_path = "./checkpoints/conv_net.onnx"
int8_onnx_path = "./checkpoints/conv_net_int8.onnx"

# Static quantization in QDQ format with signed int8 weights and activations; the activation
# ranges are calibrated from the batches supplied by data_reader (the whole dataset above).
quantize_static(
    model_input=fp32_onnx_path,
    model_output=int8_onnx_path,
    calibration_data_reader=data_reader,
    quant_format=QuantFormat.QDQ,
    activation_type=QuantType.QInt8,
    weight_type=QuantType.QInt8,
)

