In [1]:
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import numpy as np
import copy

from run_model import evaluate_model, train_one_epoch
from run_model import save_torchscript_model, load_torchscript_model
from datasets import tfds_data_loader, data_loader
import resmlp

import torch.optim as optim
from timm.models import create_model
from timm.data import Mixup
from timm.loss import SoftTargetCrossEntropy
from timm.utils import NativeScaler, get_state_dict, ModelEma

## Set Parameters

In [2]:
INPUT_SIZE = 224
DICT_PATH = 'E:\ResMLP_QAT\pytorch\ResMLP_S24_ReLU_99dense.pth' 

DATA_NAME = 'imagenet2012'
DATA_DIR = 'E:\datasets'

BATCH_SIZE = 32
EPOCHS = 5
LR = 1e-4

WORKERS = 0 #8

## Speed Up Configs

In [3]:
# Data Parallel Training (DPT) settings

# device = CUDA
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# set seed
seed = 336 # args.seed + utils.get_rank()
torch.manual_seed(seed)
np.random.seed(seed)

# check for best cudnn ops before training starts.
cudnn.benchmark = True

## Quantize

In [4]:
class QuantizedResMLP(nn.Module):
    def __init__(self, model_fp32):
        super(QuantizedResMLP, self).__init__()
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()
        self.model_fp32 = model_fp32

    def forward(self, x):
        x = self.quant(x)
        x = self.model_fp32(x)
        x = self.dequant(x)
        return x

# build train/val dataset
# create sampler (if dataset from tfds, can't apply sampler) (distributed ver. to be done)
# build up dataloader
data_loader_train, data_loader_val, NUM_CLASSES = tfds_data_loader(
    name=DATA_NAME,
    root=DATA_DIR,
    input_size=INPUT_SIZE, 
    batch_size=BATCH_SIZE,
    num_workers=WORKERS,
)

# additional data augmentation (mixup)
mixup_fn = Mixup(
    mixup_alpha=0.8, cutmix_alpha=1.0, cutmix_minmax=None,
    prob=1.0, switch_prob=0.5, mode='batch',
    label_smoothing=0.1, num_classes=NUM_CLASSES)  

# create model
model = create_model('resmlp_24', num_classes=NUM_CLASSES).cuda()
model.load_state_dict(torch.load(DICT_PATH), strict=False)

# fuse
fused_model = copy.deepcopy(model)
for basic_block_name, basic_block in fused_model.blocks.named_children():
  for sub_block_name, sub_block in basic_block.named_children():
    if sub_block_name == "mlp":
      torch.quantization.fuse_modules(
        sub_block, [['fc1', 'act']],
        inplace=True)

# apply quant/dequant stabs
quantized_model = QuantizedResMLP(model_fp32=fused_model)
quantized_model.train()

# quantization configurations
quantized_model.qconfig = torch.quantization.get_default_qat_qconfig('qnnpack')
print(quantized_model.qconfig)



QConfig(activation=functools.partial(<class 'torch.ao.quantization.fake_quantize.FusedMovingAvgObsFakeQuantize'>, observer=<class 'torch.ao.quantization.observer.MovingAverageMinMaxObserver'>, quant_min=0, quant_max=255, reduce_range=False){}, weight=functools.partial(<class 'torch.ao.quantization.fake_quantize.FusedMovingAvgObsFakeQuantize'>, observer=<class 'torch.ao.quantization.observer.MovingAverageMinMaxObserver'>, quant_min=-128, quant_max=127, dtype=torch.qint8, qscheme=torch.per_tensor_symmetric){})


In [5]:
# train
print("Training QAT Model...")
quantized_model.train()
torch.quantization.prepare_qat(quantized_model, inplace=True)

#criterion = nn.CrossEntropyLoss()
criterion = SoftTargetCrossEntropy()
optimizer = optim.SGD(quantized_model.parameters(),
                        lr=LR,
                        momentum=0.9,
                        weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                    milestones=[100, 150],
                                                    gamma=0.1,
                                                    last_epoch=-1)                                                  
                                                    
train_one_epoch(model=quantized_model, criterion=criterion,
                  data_loader=data_loader_train, optimizer=optimizer,
                  device=device, epoch=1, max_norm=None,
                  model_ema=None, mixup_fn=mixup_fn)
#training...

Training QAT Model...


  return torch.fused_moving_avg_obs_fake_quant(
i: 1080 Eval Loss: 2.464:   3%|▎         | 1087/40036 [09:01<5:43:58,  1.89it/s]

In [7]:
# convert weight to int8, replace model to quantized ver.
quantized_model.cpu()
torch.quantization.convert(quantized_model, inplace=True)
quantized_model.eval()

input_fp32 = torch.randn((1, 3, 224, 224), dtype=torch.float32, device="cpu")
quantized_model(input_fp32)

SAVE_PATH = 'modeltest.pth'
save_torchscript_model(model=quantized_model, 
                        model_dir='qat_weights', 
                        model_filename='qat_Test0.pth')

In [8]:
criterion = nn.CrossEntropyLoss().cuda()
#model = create_model('resmlp_24', num_classes=NUM_CLASSES).cuda()
quantized_model = load_torchscript_model(model_filepath='qat_weights/qat_Test0.pth', device="cpu")

#model.load_state_dict(torch.load(DICT_PATH), strict=False)
# Evaluation
quantized_model.eval()
eval_loss, top1_acc, top5_acc = evaluate_model(model=quantized_model,
                                                test_loader=data_loader_val,
                                                device="cpu",
                                                criterion=criterion)
print("Epoch: {:02d} Eval Loss: {:.3f} Top1: {:.3f} Top5: {:.3f}".format(
    -1, eval_loss, top1_acc, top5_acc))

100%|██████████| 1042/1042 [29:09<00:00,  1.68s/it]

Epoch: -1 Eval Loss: 1.842 Top1: 74.376 Top5: 90.640





In [None]:
from importlib import reload

import run_model
reload(run_model)
from run_model import evaluate_model, train_one_epoch
from run_model import save_torchscript_model, load_torchscript_model