# Quantization Aware Training Sample Code

In [1]:
import os
import random

import torch
import torch.nn as nn
import torchvision

import time
import copy
import numpy as np
from torchvision import transforms
from tqdm import tqdm

torch.backends.cudnn.benchmark = True

def set_random_seeds(random_seed=0):

    torch.manual_seed(random_seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(random_seed)
    random.seed(random_seed)

def memory_check():
    print(f"  Allocated: {round(torch.cuda.memory_allocated()/1024**3,2)} GB")
    print(f"  Cached:    {round(torch.cuda.memory_cached()/1024**3,2)} GB\n")

print(f"torch = {torch.__version__}")
print(f"torchvision = {torchvision.__version__}")


torch = 1.13.0a0+08820cb
torchvision = 0.14.0a0


  from .autonotebook import tqdm as notebook_tqdm


## Make ImageNet(validation 6G) Data Loader 

In [2]:
import wget
if not os.path.exists("./data/ImageNet/meta.bin"):
    print("Meta data download")
    wget.download(url="https://image-net.org/data/ILSVRC/2012/ILSVRC2012_devkit_t12.tar.gz", out="./data/ImageNet")
# if not os.path.exists("./data/ImageNet/ILSVRC2012_devkit_t3.tar.gz"):
#     print("Toolkit t3 Download")
#     toolkit_url = "https://image-net.org/data/ILSVRC/2012/ILSVRC2012_devkit_t3.tar.gz"
#     wget.download(url= toolkit_url,out="./data/ImageNet")
if not os.path.exists("./data/ImageNet/ILSVRC2012_img_val.tar"):
    print("Download val data")
    val_url  = 'https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_val.tar'
    wget.download(url=val_url, out="./data/ImageNet")

# if not os.path.exists("./data/ImageNet/ILSVRC2012_img_train_t3.tar"):
#     print("Download train t3 data")
#     train_url = "https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_train_t3.tar"
#     wget.download(url=train_url,out="./data/ImageNet")

In [3]:
def Make_loader(split_num = [0.007,0.003,0.99]):
    train_transform = transforms.Compose([
        transforms.Resize((224,224)),
        transforms.ToTensor(),
        # transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
    ])

    test_transform = transforms.Compose([
        transforms.Resize((224,224)),
        transforms.ToTensor(),
        # transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
    ])
    print(os.getcwd())
    dataset = torchvision.datasets.ImageNet(root="./data/ImageNet",split="val", transform = train_transform)
    Train_dataset, Test_dataset,_ = torch.utils.data.random_split(dataset, split_num)
    print(f"Train data set = {len(Train_dataset)}, Test = {len(Test_dataset)}")

    Train_loader = torch.utils.data.DataLoader(dataset=Train_dataset, batch_size=1, shuffle = True)
    Test_loader = torch.utils.data.DataLoader(dataset=Test_dataset, batch_size =1, shuffle = False)
    return Train_loader, Test_loader

## MobileNetV2

In [4]:
model = torchvision.models.mobilenet_v2(weights=torchvision.models.MobileNet_V2_Weights.DEFAULT)
from torchsummary import summary
summary(model,(3,224,224), device='cpu')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 32, 128, 128]             864
       BatchNorm2d-2         [-1, 32, 128, 128]              64
             ReLU6-3         [-1, 32, 128, 128]               0
            Conv2d-4         [-1, 32, 128, 128]             288
       BatchNorm2d-5         [-1, 32, 128, 128]              64
             ReLU6-6         [-1, 32, 128, 128]               0
            Conv2d-7         [-1, 16, 128, 128]             512
       BatchNorm2d-8         [-1, 16, 128, 128]              32
  InvertedResidual-9         [-1, 16, 128, 128]               0
           Conv2d-10         [-1, 96, 128, 128]           1,536
      BatchNorm2d-11         [-1, 96, 128, 128]             192
            ReLU6-12         [-1, 96, 128, 128]               0
           Conv2d-13           [-1, 96, 64, 64]             864
      BatchNorm2d-14           [-1, 96,

# Train and Evaluate Fuc

In [5]:
def Evaluating(model, test_loader, device, criterion):

    model.eval()
    model.to(device)

    running_loss = 0
    running_corrects = 0

    for inputs, labels in test_loader:

        inputs = inputs.to(device)
        labels = labels.to(device)

        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)

        loss = criterion(outputs, labels).item()

        # statistics
        running_loss += loss * inputs.size(0)
        running_corrects += torch.sum(preds == labels.data)

    eval_loss = running_loss / len(test_loader.dataset)
    eval_accuracy = running_corrects / len(test_loader.dataset)

    return eval_loss, eval_accuracy

In [6]:
def Training(model, train_loader, test_loader, device, optimizer, scheduler, epochs=100):
    criterion = nn.CrossEntropyLoss()
    print("Before Training")
    torch.cuda.memory_reserved()
    memory_check()
    # Training
    model.to(device)
    for epoch in range(epochs):

        running_loss = 0
        running_corrects = 0
        model.train()

        for inputs, labels in tqdm(iter(train_loader)):

            inputs = inputs.to(device)
            labels = labels.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = model(inputs)

            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)
 
            loss.backward()
            optimizer.step()
            
            # statistics
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)
  
        train_loss = running_loss / len(train_loader.dataset)
        train_accuracy = running_corrects / len(train_loader.dataset)

        # Evaluation
        model.eval()
        val_loss, val_acc = Evaluating(model,test_loader,device=device,criterion=criterion)
        print(f"--------{epoch}----------")
        print(f"Train {train_loss:.4f} Loss, {train_accuracy:.2f} Acc")
        print(f"Validation {val_loss:.4f} Loss, {val_acc:.2f} Acc")

        # Set learning rate scheduler
        scheduler.step()

    return model

## Layer fusion Check
conv, bn, relu를 하나의 layer로 만들어 각각의 layer를 읽어오는 연산을 줄이는 과정   
folding과는 다른 경량화 기법   
Fusion 된 layer는 identity로 바뀜

In [7]:
def model_eq_check(model1, model2, device, rtol=1e-03, atol=1e-06, num_tests=100, input_size=(1,3,256,256)):

    model1.to(device)
    model2.to(device)

    for _ in range(num_tests):
        x = torch.rand(size=input_size).to(device)
        y1 = model1(x).detach().cpu().numpy()
        y2 = model2(x).detach().cpu().numpy()
        # 배열이 허용 오차범위 abs(a - b) <= (atol + rtol * absolute(b)) 이내면 True
        if np.allclose(a=y1, b=y2, rtol=rtol, atol=atol, equal_nan=False) == False:
            print("Model equivalence test fail")
            return False
    print("Two models equal")
    return True

In [8]:
def time_test(model, device, input_size = (1,3,256,256),num_tests=100,):
    model.to(device)
    model.eval()

    x = torch.rand(size=input_size).to(device)

    with torch.no_grad():
        for _ in range(10):
            _ = model(x)
    torch.cuda.synchronize()

    with torch.no_grad():
        start_time = time.time()

        for _ in range(num_tests):
            _ = model(x)
            torch.cuda.synchronize()
        total_time = time.time() - start_time

    aver_time = total_time / num_tests
    return total_time, aver_time

In [9]:
class ConvBnReLUModel(nn.Module):
    def __init__(self):
        super(ConvBnReLUModel,self).__init__()
        self.conv = nn.Conv2d(3,5,3,bias=True).to(dtype=torch.float)
        self.bn = nn.BatchNorm2d(5).to(dtype=torch.float)
        self.relu = torch.nn.ReLU(inplace=True)
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()
    
    def forward(self,x):
        x = self.quant(x)
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        x = self.dequant(x)
        return x
    
model = ConvBnReLUModel().to(device=torch.device("cpu:0"))
model.eval()
print(model)
# for p in model.named_parameters():
#     print(p)
#     print()
# "fbgemm" for server , "qnnpack" for mobile 
# model.qconfig = torch.quantization.get_default_qconfig('fbgemm')

# # torch.quantization.fuse_module or myModel.fuse_model()
fuse_model = torch.ao.quantization.fuse_modules(model,[['conv','bn','relu']], inplace=False)
# fuse_model = model.fuse_model()
print(fuse_model)

print(f"-- Equal Test --")
model_eq_check(model, fuse_model, device=torch.device("cpu:0"))


print(f"-- Infer Time Test --")
ori_cpu_time,_ = time_test(model,torch.device("cpu"))
fus_cpu_time,_ = time_test(fuse_model,torch.device("cpu"))

print(f"origin model infer time {ori_cpu_time:.3f}s")
print(f"fusion model infer time {fus_cpu_time:.3f}s")



ConvBnReLUModel(
  (conv): Conv2d(3, 5, kernel_size=(3, 3), stride=(1, 1))
  (bn): BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (quant): QuantStub()
  (dequant): DeQuantStub()
)
ConvBnReLUModel(
  (conv): ConvReLU2d(
    (0): Conv2d(3, 5, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU(inplace=True)
  )
  (bn): Identity()
  (relu): Identity()
  (quant): QuantStub()
  (dequant): DeQuantStub()
)
-- Equal Test --
Two models equal
-- Infer Time Test --
origin model infer time 0.055s
fusion model infer time 0.041s


# MAIN

In [10]:
class QuantizationModel(nn.Module):
    def __init__(self, model):
        super(QuantizationModel, self).__init__()
        
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()
        self.fu_model = model

    def forward(self,x):
        # Convert tensor from float32 to int8
        x = self.quant(x)
        x = self.fu_model(x)
        # Convert tensor from int8 to float32
        x = self.dequant(x)
        return x

In [11]:
# gpu,cpu device 선언
if torch.cuda.is_available():
    gpu_device = torch.device("cuda:0")
cpu_device = torch.device("cpu:0")

set_random_seeds(42)

# model 가져오기
# MobileNetv2는 이미 Layer fusion이 되어 있다.
# model = torchvision.models.mobilenet_v2(weights=torchvision.models.MobileNet_V2_Weights.DEFAULT)
model = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)

# 모델 저장하고 불러오기 jit 
torch.save(model.state_dict(), "test.pth")
# Load a pretrained model.
model.load_state_dict(torch.load("test.pth", map_location=gpu_device)) 

# Move the model to CPU since static quantization does not support CUDA currently.
model.to(cpu_device)
# ImageNet Data 
Train_loader, Test_loader = Make_loader()

/seunmul/QAT
Train data set = 350, Test = 150


In [12]:

# 모델을 CPU상태로 두고 eval로 layer fusion
model.eval()

# Layer fusion
fused_model = torch.quantization.fuse_modules(model,[["conv1","bn1","relu"]])

for module_name, module in fused_model.named_children():
    if "layer" in module_name:
        # basic_block 의 conv1, bn1, relu, conv2, bn2 를 fusion
        for basic_block_name, basic_block in module.named_children():
            torch.ao.quantization.fuse_modules(basic_block,[["conv1","bn1","relu"],["conv2","bn2"]],inplace=True)
            # basic_block안의 downsampling block의 Conv2d Batchnorm2D fusion
            for sub_block_name, sub_block in basic_block.named_children():
                if sub_block_name == "downsample":
                    torch.ao.quantization.fuse_modules(sub_block,[["0","1"]], inplace=True)
print(fused_model)

# Equal Test
print(f"Equal Test between origin and fused")
print(model_eq_check(model,fused_model,device=cpu_device))

ResNet(
  (conv1): ConvReLU2d(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
    (1): ReLU(inplace=True)
  )
  (bn1): Identity()
  (relu): Identity()
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): ConvReLU2d(
        (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
        (1): ReLU(inplace=True)
      )
      (bn1): Identity()
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn2): Identity()
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): Identity()
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))
        (1): Identity()
      )
    )
    (1): Bottleneck(
      (conv1): ConvReLU2d(
        (0): Conv2d(256, 64, kernel_size=(1, 1), st

In [13]:
quat_model = QuantizationModel(model=fused_model)
quat_model.train()
# qconfig("fbgemm") 은 server 용 "qnnpack"은 mobile용 ["fbgemm", "x86", "qnnpack", "onednn"]
quat_model.qconfig = torch.quantization.get_default_qconfig("fbgemm")
print(quat_model.qconfig)

# QAT를 하기위해 quantization 모델 준비
quat_model = torch.quantization.prepare_qat(quat_model)
print(quat_model.fu_model)

# print('Inverted Residual Block: After preparation for QAT, note fake-quantization modules \n',quat_model.features[1].conv)

QConfig(activation=functools.partial(<class 'torch.ao.quantization.observer.HistogramObserver'>, reduce_range=True){}, weight=functools.partial(<class 'torch.ao.quantization.observer.PerChannelMinMaxObserver'>, dtype=torch.qint8, qscheme=torch.per_channel_symmetric){})




ResNet(
  (conv1): ConvReLU2d(
    3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3)
    (weight_fake_quant): PerChannelMinMaxObserver(min_val=tensor([]), max_val=tensor([]))
    (activation_post_process): HistogramObserver()
  )
  (bn1): Identity()
  (relu): Identity()
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): ConvReLU2d(
        64, 64, kernel_size=(1, 1), stride=(1, 1)
        (weight_fake_quant): PerChannelMinMaxObserver(min_val=tensor([]), max_val=tensor([]))
        (activation_post_process): HistogramObserver()
      )
      (bn1): Identity()
      (conv2): Conv2d(
        64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
        (weight_fake_quant): PerChannelMinMaxObserver(min_val=tensor([]), max_val=tensor([]))
        (activation_post_process): HistogramObserver()
      )
      (bn2): Identity()
      (conv3): Conv2d(
        64, 256, kernel_size=(1, 1), s

In [14]:
optimizer = torch.optim.SGD(quat_model.parameters(), lr=1e-05, momentum=0.9, weight_decay=1e-4)
# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=500)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 150], gamma=0.1, last_epoch=-1)

print("Before Training")
quat_model.eval()
val_loss, val_acc = Evaluating(quat_model,Test_loader,device=gpu_device,criterion=nn.CrossEntropyLoss())
print(f"Before Loss : {val_loss:.4f}, Before Acc : {val_acc:.1f}")
# quat_model = Training(quat_model,train_loader=Train_loader,test_loader=Test_loader,device=gpu_device,optimizer=optimizer,scheduler=scheduler,epochs=10)

# QAT가 적용된 floating point 모델을 quantized int model로 변환
quat_model.to(cpu_device)

Before Training
Before Loss : 655550998009.1733, Before Acc : 0.0


QuantizationModel(
  (quant): QuantStub(
    (activation_post_process): HistogramObserver()
  )
  (dequant): DeQuantStub()
  (fu_model): ResNet(
    (conv1): ConvReLU2d(
      3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3)
      (weight_fake_quant): PerChannelMinMaxObserver(
        min_val=tensor([-2.0376, -0.9499, -0.7120, -0.6238, -0.2866, -1.6196, -1.6922, -0.8482,
                -0.3329, -1.9136, -1.6600, -1.3270, -0.3801, -0.2831, -1.7473, -0.6589,
                -1.0246, -1.2643, -1.5803, -0.0614, -0.8400, -1.7917, -0.5158, -0.7798,
                -0.2384, -0.7317, -0.6917, -1.1549, -2.6805, -0.2687, -1.9068, -1.3519,
                -0.3332, -0.6825, -1.5521, -0.9448, -1.6199, -2.0809, -1.6681, -0.0631,
                -0.7764, -0.9503, -1.1462, -1.1072, -1.3614, -1.6682, -0.9263, -0.8795,
                -2.5278, -1.2078, -1.8009, -4.3954, -0.3088, -2.2456, -0.2964, -2.0849,
                -0.9609, -0.7241, -0.5984, -0.3645, -1.1131, -0.8273, -2.5904, -1.8564]), 