์ ์: Raghuraman Krishnamoorthi ํธ์ง: Seth Weidman, Jerry Zhang ๋ฒ์ญ: ๊นํ๊ธธ, Choi Yoonjeong
์ด ํํ ๋ฆฌ์ผ์์๋ ์ด๋ป๊ฒ ํ์ต ํ ์ ์ ์์ํ(post-training static quantization)๋ฅผ ํ๋์ง ๋ณด์ฌ์ฃผ๋ฉฐ, ๋ชจ๋ธ์ ์ ํ๋(accuracy)์ ๋์ฑ ๋์ด๊ธฐ ์ํ ๋ ๊ฐ์ง ๊ณ ๊ธ ๊ธฐ์ ์ธ ์ฑ๋๋ณ ์์ํ(per-channel quantization)์ ์์ํ ์๊ฐ ํ์ต(quantization-aware training)๋ ์ดํด๋ด ๋๋ค. ํ์ฌ ์์ํ๋ CPU๋ง ์ง์ํ๊ธฐ์, ์ด ํํ ๋ฆฌ์ผ์์๋ GPU/ CUDA๋ฅผ ์ด์ฉํ์ง ์์ต๋๋ค. ์ด ํํ ๋ฆฌ์ผ์ ๋๋ด๋ฉด PyTorch์์ ์์ํ๊ฐ ์ด๋ป๊ฒ ์๋๋ ํฅ์์ํค๋ฉด์ ๋ชจ๋ธ ์ฌ์ด์ฆ๋ฅผ ํฐ ํญ์ผ๋ก ์ค์ด๋์ง ํ์ธํ ์ ์์ต๋๋ค. ๊ฒ๋ค๊ฐ ์ฌ๊ธฐ ์ ์๊ฐ๋ ๋ช๋ช ๊ณ ๊ธ ์์ํ ๊ธฐ์ ์ ์ผ๋ง๋ ์ฝ๊ฒ ์ ์ฉํ๋์ง๋ ๋ณผ ์ ์๊ณ , ์ด๋ฐ ๊ธฐ์ ๋ค์ด ๋ค๋ฅธ ์์ํ ๊ธฐ์ ๋ค๋ณด๋ค ๋ชจ๋ธ์ ์ ํ๋์ ๋ถ์ ์ ์ธ ์ํฅ์ ๋ ๋ผ์น๋ ๊ฒ๋ ๋ณผ ์ ์์ต๋๋ค.
์ฃผ์: ๋ค๋ฅธ PyTorch ์ ์ฅ์์ ์์ฉ๊ตฌ ์ฝ๋(boilerplate code)๋ฅผ ๋ง์ด ์ฌ์ฉํฉ๋๋ค.
์๋ฅผ ๋ค์ด MobileNetV2
๋ชจ๋ธ ์ํคํ
์ฒ ์ ์, DataLoader ์ ์ ๊ฐ์ ๊ฒ๋ค์
๋๋ค.
๋ฌผ๋ก ์ด๋ฐ ์ฝ๋๋ค์ ์ฝ๋ ๊ฒ์ ์ถ์ฒํ์ง๋ง, ์์ํ ํน์ง๋ง ์๊ณ ์ถ๋ค๋ฉด
"4. ํ์ต ํ ์ ์ ์์ํ" ๋ถ๋ถ์ผ๋ก ๋์ด๊ฐ๋ ๋ฉ๋๋ค.
ํ์ํ ๊ฒ๋ค์ import ํ๋ ๊ฒ๋ถํฐ ์์ํด ๋ด
์๋ค:
import os
import sys
import time
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision
from torchvision import datasets
import torchvision.transforms as transforms
# # warnings ์ค์
import warnings
warnings.filterwarnings(
action='ignore',
category=DeprecationWarning,
module=r'.*'
)
warnings.filterwarnings(
action='default',
module=r'torch.ao.quantization'
)
# ๋ฐ๋ณต ๊ฐ๋ฅํ ๊ฒฐ๊ณผ๋ฅผ ์ํ ๋๋ค ์๋ ์ง์ ํ๊ธฐ
torch.manual_seed(191009)
์ฒ์์ผ๋ก MobileNetV2 ๋ชจ๋ธ ์ํคํ ์ฒ๋ฅผ ์ ์ํฉ๋๋ค. ์ด ๋ชจ๋ธ์ ์์ํ๋ฅผ ์ํ ๋ช ๊ฐ์ง ์ค์ํ ๋ณ๊ฒฝ์ฌํญ๋ค์ด ์์ต๋๋ค:
- ๋ง์
์
nn.quantized.FloatFunctional
์ผ๋ก ๊ต์ฒด - ์ ๊ฒฝ๋ง์ ์ฒ์๊ณผ ๋์
QuantStub
๋ฐDeQuantStub
์ฝ์ - ReLU๋ฅผ ReLU6๋ก ๊ต์ฒด
์๋ฆผ: ์ด ์ฝ๋๋ ์ฌ๊ธฐ ์์ ๊ฐ์ ธ์์ต๋๋ค.
from torch.ao.quantization import QuantStub, DeQuantStub
def _make_divisible(v, divisor, min_value=None):
"""
์ด ํจ์๋ ์๋ณธ TensorFlow ์ ์ฅ์์์ ๊ฐ์ ธ์์ต๋๋ค.
๋ชจ๋ ๊ณ์ธต์ด 8๋ก ๋๋์ด์ง๋ ์ฑ๋ ์ซ์๋ฅผ ๊ฐ์ง๊ณ ์์ต๋๋ค.
์ด๊ณณ์์ ํ์ธ ๊ฐ๋ฅํฉ๋๋ค:
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
:param v:
:param divisor:
:param min_value:
:return:
"""
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
# ๋ด๋ฆผ์ 10% ๋๊ฒ ๋ด๋ ค๊ฐ์ง ์๋ ๊ฒ์ ๋ณด์ฅํฉ๋๋ค.
if new_v < 0.9 * v:
new_v += divisor
return new_v
class ConvBNReLU(nn.Sequential):
def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
padding = (kernel_size - 1) // 2
super(ConvBNReLU, self).__init__(
nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
nn.BatchNorm2d(out_planes, momentum=0.1),
# ReLU๋ก ๊ต์ฒด
nn.ReLU(inplace=False)
)
class InvertedResidual(nn.Module):
def __init__(self, inp, oup, stride, expand_ratio):
super(InvertedResidual, self).__init__()
self.stride = stride
assert stride in [1, 2]
hidden_dim = int(round(inp * expand_ratio))
self.use_res_connect = self.stride == 1 and inp == oup
layers = []
if expand_ratio != 1:
# pw
layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
layers.extend([
# dw
ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
# pw-linear
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
nn.BatchNorm2d(oup, momentum=0.1),
])
self.conv = nn.Sequential(*layers)
# torch.add๋ฅผ floatfunctional๋ก ๊ต์ฒด
self.skip_add = nn.quantized.FloatFunctional()
def forward(self, x):
if self.use_res_connect:
return self.skip_add.add(x, self.conv(x))
else:
return self.conv(x)
class MobileNetV2(nn.Module):
def __init__(self, num_classes=1000, width_mult=1.0, inverted_residual_setting=None, round_nearest=8):
"""
MobileNet V2 ๋ฉ์ธ ํด๋์ค
Args:
num_classes (int): ํด๋์ค ์ซ์
width_mult (float): ๋์ด multiplier - ์ด ์๋ฅผ ํตํด ๊ฐ ๊ณ์ธต์ ์ฑ๋ ๊ฐ์๋ฅผ ์กฐ์
inverted_residual_setting: ๋คํธ์ํฌ ๊ตฌ์กฐ
round_nearest (int): ๊ฐ ๊ณ์ธต์ ์ฑ๋ ์ซ๋ฅผ ์ด ์ซ์์ ๋ฐฐ์๋ก ๋ฐ์ฌ๋ฆผ
1๋ก ์ค์ ํ๋ฉด ๋ฐ์ฌ๋ฆผ ์ ์ง
"""
super(MobileNetV2, self).__init__()
block = InvertedResidual
input_channel = 32
last_channel = 1280
if inverted_residual_setting is None:
inverted_residual_setting = [
# t, c, n, s
[1, 16, 1, 1],
[6, 24, 2, 2],
[6, 32, 3, 2],
[6, 64, 4, 2],
[6, 96, 3, 1],
[6, 160, 3, 2],
[6, 320, 1, 1],
]
# ์ฌ์ฉ์๊ฐ t,c,n,s๋ฅผ ํ์ํ๋ค๋ ๊ฒ์ ์๋ค๋ ์ ์ ํ์ ์ฒซ ๋ฒ์งธ ์์๋ง ํ์ธ
if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
raise ValueError("inverted_residual_setting should be non-empty "
"or a 4-element list, got {}".format(inverted_residual_setting))
# ์ฒซ ๋ฒ์งธ ๊ณ์ธต ๋ง๋ค๊ธฐ
input_channel = _make_divisible(input_channel * width_mult, round_nearest)
self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
features = [ConvBNReLU(3, input_channel, stride=2)]
# ์ญ์ ๋ ์์ฐจ ๋ธ๋ญ(inverted residual blocks) ๋ง๋ค๊ธฐ
for t, c, n, s in inverted_residual_setting:
output_channel = _make_divisible(c * width_mult, round_nearest)
for i in range(n):
stride = s if i == 0 else 1
features.append(block(input_channel, output_channel, stride, expand_ratio=t))
input_channel = output_channel
# ๋ง์ง๋ง ๊ณ์ธต๋ค ๋ง๋ค๊ธฐ
features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1))
# nn.Sequential๋ก ๋ง๋ค๊ธฐ
self.features = nn.Sequential(*features)
self.quant = QuantStub()
self.dequant = DeQuantStub()
# ๋ถ๋ฅ๊ธฐ(classifier) ๋ง๋ค๊ธฐ
self.classifier = nn.Sequential(
nn.Dropout(0.2),
nn.Linear(self.last_channel, num_classes),
)
# ๊ฐ์ค์น ์ด๊ธฐํ
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out')
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.BatchNorm2d):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
nn.init.zeros_(m.bias)
def forward(self, x):
x = self.quant(x)
x = self.features(x)
x = x.mean([2, 3])
x = self.classifier(x)
x = self.dequant(x)
return x
# ์์ํ ์ ์ Conv+BN๊ณผ Conv+BN+Relu ๋ชจ๋ ๊ฒฐํฉ(fusion)
# ์ด ์ฐ์ฐ์ ์ซ์๋ฅผ ๋ณ๊ฒฝํ์ง ์์
def fuse_model(self, is_qat=False):
fuse_modules = torch.ao.quantization.fuse_modules_qat if is_qat else torch.ao.quantization.fuse_modules
for m in self.modules():
if type(m) == ConvBNReLU:
fuse_modules(m, ['0', '1', '2'], inplace=True)
if type(m) == InvertedResidual:
for idx in range(len(m.conv)):
if type(m.conv[idx]) == nn.Conv2d:
fuse_modules(m.conv, [str(idx), str(idx + 1)], inplace=True)
๋ค์์ผ๋ก ๋ชจ๋ธ ํ๊ฐ๋ฅผ ์ํ ํฌํผ ํจ์๋ค์ ๋ง๋ญ๋๋ค. ์ฝ๋ ๋๋ถ๋ถ์ ์ฌ๊ธฐ ์์ ๊ฐ์ ธ์์ต๋๋ค.
class AverageMeter(object):
"""ํ๊ท ๊ณผ ํ์ฌ ๊ฐ ๊ณ์ฐ ๋ฐ ์ ์ฅ"""
def __init__(self, name, fmt=':f'):
self.name = name
self.fmt = fmt
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
def __str__(self):
fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
return fmtstr.format(**self.__dict__)
def accuracy(output, target, topk=(1,)):
"""ํน์ k๊ฐ์ ์ํด top k ์์ธก์ ์ ํ๋ ๊ณ์ฐ"""
with torch.no_grad():
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = []
for k in topk:
correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
res.append(correct_k.mul_(100.0 / batch_size))
return res
def evaluate(model, criterion, data_loader, neval_batches):
model.eval()
top1 = AverageMeter('Acc@1', ':6.2f')
top5 = AverageMeter('Acc@5', ':6.2f')
cnt = 0
with torch.no_grad():
for image, target in data_loader:
output = model(image)
loss = criterion(output, target)
cnt += 1
acc1, acc5 = accuracy(output, target, topk=(1, 5))
print('.', end = '')
top1.update(acc1[0], image.size(0))
top5.update(acc5[0], image.size(0))
if cnt >= neval_batches:
return top1, top5
return top1, top5
def load_model(model_file):
model = MobileNetV2()
state_dict = torch.load(model_file)
model.load_state_dict(state_dict)
model.to('cpu')
return model
def print_size_of_model(model):
torch.save(model.state_dict(), "temp.p")
print('Size (MB):', os.path.getsize("temp.p")/1e6)
os.remove('temp.p')
๋ง์ง๋ง ์ฃผ์ ์ค์ ๋จ๊ณ๋ก์ ํ์ต๊ณผ ํ ์คํธ ๋ฐ์ดํฐ๋ฅผ ์ํ DataLoader๋ฅผ ์ ์ํฉ๋๋ค.
์ ์ฒด ImageNet Dataset์ ์ด์ฉํด์ ์ด ํํ ๋ฆฌ์ผ์ ์ฝ๋๋ฅผ ์คํ์ํค๊ธฐ ์ํด, ์ฒซ๋ฒ์งธ๋ก ImageNet Data ์ ์ง์๋ฅผ ๋ฐ๋ผ ImageNet์ ๋ค์ด๋ก๋ํฉ๋๋ค. ๋ค์ด๋ก๋ํ ํ์ผ์ ์์ถ์ 'data_path'์ ํ๋๋ค.
๋ค์ด๋ก๋๋ฐ์ ๋ฐ์ดํฐ๋ฅผ ์ฝ๊ธฐ ์ํด ์๋์ ์ ์๋ DataLoader ํจ์๋ค์ ์ฌ์ฉํฉ๋๋ค. ์ด๋ฐ ํจ์๋ค ๋๋ถ๋ถ์ ์ฌ๊ธฐ ์์ ๊ฐ์ ธ์์ต๋๋ค.
def prepare_data_loaders(data_path):
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
dataset = torchvision.datasets.ImageNet(
data_path, split="train", transform=transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize,
]))
dataset_test = torchvision.datasets.ImageNet(
data_path, split="val", transform=transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
normalize,
]))
train_sampler = torch.utils.data.RandomSampler(dataset)
test_sampler = torch.utils.data.SequentialSampler(dataset_test)
data_loader = torch.utils.data.DataLoader(
dataset, batch_size=train_batch_size,
sampler=train_sampler)
data_loader_test = torch.utils.data.DataLoader(
dataset_test, batch_size=eval_batch_size,
sampler=test_sampler)
return data_loader, data_loader_test
๋ค์์ผ๋ก ์ฌ์ ์ ํ์ต๋ MobileNetV2์ ๋ถ๋ฌ์ต๋๋ค. ๋ชจ๋ธ์ ๋ค์ด๋ก๋ ๋ฐ์ ์ ์๋ URL์ `์ฌ๊ธฐ <<https://download.pytorch.org/models/mobilenet_v2-b0353104.pth>>`_ ์์ ์ ๊ณตํฉ๋๋ค.
data_path = '~/.data/imagenet'
saved_model_dir = 'data/'
float_model_file = 'mobilenet_pretrained_float.pth'
scripted_float_model_file = 'mobilenet_quantization_scripted.pth'
scripted_quantized_model_file = 'mobilenet_quantization_scripted_quantized.pth'
train_batch_size = 30
eval_batch_size = 50
data_loader, data_loader_test = prepare_data_loaders(data_path)
criterion = nn.CrossEntropyLoss()
float_model = load_model(saved_model_dir + float_model_file).to('cpu')
# ๋ค์์ผ๋ก "๋ชจ๋ ๊ฒฐํฉ"์ ํฉ๋๋ค. ๋ชจ๋ ๊ฒฐํฉ์ ๋ฉ๋ชจ๋ฆฌ ์ ๊ทผ์ ์ค์ฌ ๋ชจ๋ธ์ ๋น ๋ฅด๊ฒ ๋ง๋ค๋ฉด์
# ์ ํ๋ ์์น๋ฅผ ํฅ์์ํต๋๋ค. ๋ชจ๋ ๊ฒฐํฉ์ ์ด๋ ํ ๋ชจ๋ธ์๋ผ๋ ์ฌ์ฉํ ์ ์์ง๋ง,
# ์์ํ๋ ๋ชจ๋ธ์ ์ฌ์ฉํ๋ ๊ฒ์ด ํนํ๋ ๋ ์ผ๋ฐ์ ์
๋๋ค.
print('\n Inverted Residual Block: Before fusion \n\n', float_model.features[1].conv)
float_model.eval()
# ๋ชจ๋ ๊ฒฐํฉ
float_model.fuse_model()
# Conv+BN+Relu์ Conv+Relu ๊ฒฐํฉ์ ์ ์
print('\n Inverted Residual Block: After fusion\n\n',float_model.features[1].conv)
๋ง์ง๋ง์ผ๋ก "๊ธฐ์ค"์ด ๋ ์ ํ๋๋ฅผ ์ป๊ธฐ ์ํด, ๋ชจ๋ ๊ฒฐํฉ์ ์ฌ์ฉํ ์์ํ๋์ง ์์ ๋ชจ๋ธ์ ์ ํ๋๋ฅผ ๋ด ์๋ค.
num_eval_batches = 1000
print("Size of baseline model")
print_size_of_model(float_model)
top1, top5 = evaluate(float_model, criterion, data_loader_test, neval_batches=num_eval_batches)
print('Evaluation accuracy on %d images, %2.2f'%(num_eval_batches * eval_batch_size, top1.avg))
torch.jit.save(torch.jit.script(float_model), saved_model_dir + scripted_float_model_file)
์ ์ฒด ๋ชจ๋ธ์ 50,000๊ฐ์ ์ด๋ฏธ์ง๋ฅผ ๊ฐ์ง eval ๋ฐ์ดํฐ์ ์์ 71.9%์ ์ ํ๋๋ฅผ ๋ณด์ ๋๋ค.
์ด ๊ฐ์ด ๋น๊ต๋ฅผ ์ํ ๊ธฐ์ค์ด ๋ ๊ฒ์ ๋๋ค. ๋ค์์ผ๋ก ์์ํ๋ ๋ชจ๋ธ์ ๋ด ์๋ค.
ํ์ต ํ ์ ์ ์์ํ๋ ๋์ ์์ํ์ฒ๋ผ ๊ฐ์ค์น๋ฅผ float์์ int๋ก ๋ณํํ๋ ๊ฒ๋ฟ๋ง ์๋๋ผ ์ถ๊ฐ์ ์ธ ๋จ๊ณ๋ ์ํํฉ๋๋ค. ๋คํธ์ํฌ์ ๋ฐ์ดํฐ ๋ฐฐ์น์ ์ฒซ ๋ฒ์งธ ๊ณต๊ธ๊ณผ ๋ค๋ฅธ ํ์ฑ๊ฐ๋ค์ ๋ถํฌ ๊ฒฐ๊ณผ ๊ณ์ฐ์ด ์ด๋ฌํ ๋จ๊ณ์ ๋๋ค. (ํนํ ์ด๋ฌํ ์ถ๊ฐ์ ์ธ ๋จ๊ณ๋ ๊ณ์ฐํ ๊ฐ์ ๊ธฐ๋กํ๊ณ ์ถ์ ์ง์ ์ observer ๋ชจ๋์ ์ฝ์ ํฉ์ผ๋ก์จ ๋๋ฉ๋๋ค.) ์ด๋ฌํ ๋ถํฌ๋ค์ ์ถ๋ก ์์ ์ ํน์ ํ ๋ค๋ฅธ ํ์ฑ๊ฐ๋ค์ด ์ด๋ป๊ฒ ์์ํ๋์ด์ผ ํ๋์ง ๊ฒฐ์ ํ๋๋ฐ ์ฌ์ฉ๋ฉ๋๋ค. (๊ฐ๋จํ ๋ฐฉ๋ฒ์ผ๋ก๋ ๋จ์ํ ํ์ฑ๊ฐ๋ค์ ์ ์ฒด ๋ฒ์๋ฅผ 256๊ฐ์ ๋จ๊ณ๋ก ๋๋๋ ๊ฒ์ด์ง๋ง, ์ข ๋ ๋ณต์กํ ๋ฐฉ๋ฒ๋ ์ ๊ณตํฉ๋๋ค.) ํนํ, ์ด๋ฌํ ์ถ๊ฐ์ ์ธ ๋จ๊ณ๋ ๊ฐ ์ฐ์ฐ ์ฌ์ด์ฌ์ด์ ์์ํ๋ ๊ฐ์ float์ผ๋ก ๋ณํ - ๋ฐ int๋ก ๋๋๋ฆผ - ํ๋ ๊ฒ๋ฟ๋ง ์๋๋ผ ์์ํ๋ ๊ฐ์ ๋ชจ๋ ์ฐ์ฐ๋ค๋ผ๋ฆฌ ์ฃผ๊ณ ๋ฐ๋ ๊ฒ๋ ๊ฐ๋ฅํ๊ฒ ํ์ฌ ์์ฒญ๋ ์๋ ํฅ์์ด ๋ฉ๋๋ค.
num_calibration_batches = 32
myModel = load_model(saved_model_dir + float_model_file).to('cpu')
myModel.eval()
# Conv, bn๊ณผ relu ๊ฒฐํฉ
myModel.fuse_model()
# ์์ํ ์ค์ ๋ช
์
# ๊ฐ๋จํ min/max ๋ฒ์ ์ถ์ ๋ฐ ํ
์๋ณ ๊ฐ์ค์น ์์ํ๋ก ์์
myModel.qconfig = torch.ao.quantization.default_qconfig
print(myModel.qconfig)
torch.ao.quantization.prepare(myModel, inplace=True)
# ์ฒซ ๋ฒ์งธ ๋ณด์ (calibrate)
print('Post Training Quantization Prepare: Inserting Observers')
print('\n Inverted Residual Block:After observer insertion \n\n', myModel.features[1].conv)
# ํ์ต ๋ฐ์ดํฐ์
์ผ๋ก ๋ณด์ (calibrate)
evaluate(myModel, criterion, data_loader, neval_batches=num_calibration_batches)
print('Post Training Quantization: Calibration done')
# ์์ํ๋ ๋ชจ๋ธ๋ก ๋ณํ
torch.ao.quantization.convert(myModel, inplace=True)
# ๋ชจ๋ธ์ ๋ณด์ ํด์ผ ํ๋ค(calibrate the model)๋ ์ฌ์ฉ์ ๊ฒฝ๊ณ (user warning)๊ฐ ํ์๋ ์ ์์ง๋ง ๋ฌด์ํด๋ ๋ฉ๋๋ค.
# ์ด ๊ฒฝ๊ณ ๋ ๊ฐ ๋ชจ๋ธ ์คํ ์ ๋ชจ๋ ๋ชจ๋์ด ์คํ๋๋ ๊ฒ์ด ์๋๊ธฐ ๋๋ฌธ์ ์ผ๋ถ ๋ชจ๋์ด ๋ณด์ ๋์ง ์์ ์
# ์๋ค๋ ๊ฒฝ๊ณ ์
๋๋ค.
print('Post Training Quantization: Convert done')
print('\n Inverted Residual Block: After fusion and quantization, note fused modules: \n\n',myModel.features[1].conv)
print("Size of model after quantization")
print_size_of_model(myModel)
top1, top5 = evaluate(myModel, criterion, data_loader_test, neval_batches=num_eval_batches)
print('Evaluation accuracy on %d images, %2.2f'%(num_eval_batches * eval_batch_size, top1.avg))
์์ํ๋ ๋ชจ๋ธ์ eval ๋ฐ์ดํฐ์ ์์ 56.7%์ ์ ํ๋๋ฅผ ๋ณด์ฌ์ค๋๋ค. ์ด๋ ์์ํ ํ๋ผ๋ฏธํฐ๋ฅผ ๊ฒฐ์ ํ๊ธฐ ์ํด ๋จ์ min/max Observer๋ฅผ ์ฌ์ฉํ๊ธฐ ๋๋ฌธ์ ๋๋ค. ๊ทธ๋ผ์๋ ๋ถ๊ตฌํ๊ณ ๋ชจ๋ธ์ ํฌ๊ธฐ๋ฅผ 3.6 MB ๋ฐ์ผ๋ก ์ค์์ต๋๋ค. ์ด๋ ๊ฑฐ์ 4๋ถ์ 1 ๋ก ์ค์ด๋ ํฌ๊ธฐ์ ๋๋ค.
์ด์ ๋ํด ๋จ์ํ ๋ค๋ฅธ ์์ํ ์ค์ ์ ์ฌ์ฉํ๊ธฐ๋ง ํด๋ ์ ํ๋๋ฅผ ํฐ ํญ์ผ๋ก ํฅ์์ํฌ ์ ์์ต๋๋ค. x86 ์ํคํ ์ฒ์์ ์์ํ๋ฅผ ์ํ ๊ถ์ฅ ์ค์ ์ ๊ทธ๋๋ก ์ฐ๊ธฐ๋ง ํด๋ ๋ฉ๋๋ค. ์ด๋ฌํ ์ค์ ์ ์๋์ ๊ฐ์ต๋๋ค:
- ์ฑ๋๋ณ ๊ธฐ๋ณธ ๊ฐ์ค์น ์์ํ
- ํ์ฑ๊ฐ์ ์์งํด์ ์ต์ ํ๋ ์์ํ ํ๋ผ๋ฏธํฐ๋ฅผ ๊ณ ๋ฅด๋ ํ์คํ ๊ทธ๋จ Observer ์ฌ์ฉ
per_channel_quantized_model = load_model(saved_model_dir + float_model_file)
per_channel_quantized_model.eval()
per_channel_quantized_model.fuse_model()
# ์ด์ ์ 'fbgemm' ๋ํ ์ฌ์ ํ ์ฌ์ฉ ๊ฐ๋ฅํ์ง๋ง, 'x86'์ ๊ธฐ๋ณธ์ผ๋ก ์ฌ์ฉํ๋ ๊ฒ์ ๊ถ์ฅํฉ๋๋ค.
per_channel_quantized_model.qconfig = torch.ao.quantization.get_default_qconfig('x86')
print(per_channel_quantized_model.qconfig)
torch.ao.quantization.prepare(per_channel_quantized_model, inplace=True)
evaluate(per_channel_quantized_model,criterion, data_loader, num_calibration_batches)
torch.ao.quantization.convert(per_channel_quantized_model, inplace=True)
top1, top5 = evaluate(per_channel_quantized_model, criterion, data_loader_test, neval_batches=num_eval_batches)
print('Evaluation accuracy on %d images, %2.2f'%(num_eval_batches * eval_batch_size, top1.avg))
torch.jit.save(torch.jit.script(per_channel_quantized_model), saved_model_dir + scripted_quantized_model_file)
๋จ์ํ ์์ํ ์ค์ ๋ฐฉ๋ฒ์ ๋ณ๊ฒฝํ๋ ๊ฒ๋ง์ผ๋ก๋ ์ ํ๋๊ฐ 67.3%๋ฅผ ๋์ ์ ๋๋ก ํฅ์์ด ๋์์ต๋๋ค! ๊ทธ๋ผ์๋ ์ด ์์น๋ ์์์ ๊ตฌํ ๊ธฐ์ค๊ฐ 71.9%์์ 4ํผ์ผํธ๋ ๋ฎ์ ์์น์ ๋๋ค. ์ด์ ์์ํ ์๊ฐ ํ์ต์ ์๋ํด ๋ด ์๋ค.
์์ํ ์๊ฐ ํ์ต(QAT)์ ์ผ๋ฐ์ ์ผ๋ก ๊ฐ์ฅ ๋์ ์ ํ๋๋ฅผ ์ ๊ณตํ๋ ์์ํ ๋ฐฉ๋ฒ์ ๋๋ค. ๋ชจ๋ ๊ฐ์ค์นํ ํ์ฑ๊ฐ์ QAT๋ก ์ธํด ํ์ต ๋์ค์ ์์ ํ์ ์ญ์ ํ๋ฅผ ๋์ค "๊ฐ์ง ์์ํ"๋ฉ๋๋ค. ์ด๋ float๊ฐ์ด int8 ๊ฐ์ผ๋ก ๋ฐ์ฌ๋ฆผํ๋ ๊ฒ์ฒ๋ผ ํ๋ด๋ฅผ ๋ด์ง๋ง, ๋ชจ๋ ๊ณ์ฐ์ ์ฌ์ ํ ๋ถ๋์์์ ์ซ์๋ก ๊ณ์ฐ์ ํฉ๋๋ค. ๊ทธ๋์ ๊ฒฐ๊ตญ ํ๋ จ ๋์์ ๋ชจ๋ ๊ฐ์ค์น ์กฐ์ ์ ๋ชจ๋ธ์ด ์์ํ๋ ๊ฒ์ด๋ผ๋ ์ฌ์ค์ "์๊ฐ"ํ ์ฑ๋ก ์ด๋ฃจ์ด์ง๊ฒ ๋ฉ๋๋ค. ๊ทธ๋์ QAT๋ ์์ํ๊ฐ ์ด๋ฃจ์ด์ง๊ณ ๋๋ฉด ๋์ ์์ํ๋ ํ์ต ์ ์ ์ ์์ํ๋ณด๋ค ๋์ฒด๋ก ๋ ๋์ ์ ํ๋๋ฅผ ๋ณด์ฌ์ค๋๋ค.
์ค์ ๋ก QAT๊ฐ ์ด๋ฃจ์ด์ง๋ ์ ์ฒด ํ๋ฆ์ ์ด์ ๊ณผ ๋งค์ฐ ์ ์ฌํฉ๋๋ค:
- ์ด์ ๊ณผ ๊ฐ์ ๋ชจ๋ธ์ ์ฌ์ฉํ ์ ์์ต๋๋ค. ์์ํ ์๊ฐ ํ์ต์ ์ํ ์ถ๊ฐ์ ์ธ ์ค๋น๋ ํ์ ์์ต๋๋ค.
- ๊ฐ์ค์น์ ํ์ฑ๊ฐ ๋ค์ ์ด๋ค ์ข
๋ฅ์ ๊ฐ์ง ์์ํ๋ฅผ ์ฌ์ฉํ ๊ฒ์ธ์ง ๋ช
์ํ๋
qconfig
์ ์ฌ์ฉ์ด ํ์ํฉ๋๋ค. Observer๋ฅผ ๋ช ์ํ๋ ๊ฒ ๋์ ์ ๋ง์ด์ฃ .
๋จผ์ ํ์ต ํจ์๋ถํฐ ์ ์ํฉ๋๋ค:
def train_one_epoch(model, criterion, optimizer, data_loader, device, ntrain_batches):
model.train()
top1 = AverageMeter('Acc@1', ':6.2f')
top5 = AverageMeter('Acc@5', ':6.2f')
avgloss = AverageMeter('Loss', '1.5f')
cnt = 0
for image, target in data_loader:
start_time = time.time()
print('.', end = '')
cnt += 1
image, target = image.to(device), target.to(device)
output = model(image)
loss = criterion(output, target)
optimizer.zero_grad()
loss.backward()
optimizer.step()
acc1, acc5 = accuracy(output, target, topk=(1, 5))
top1.update(acc1[0], image.size(0))
top5.update(acc5[0], image.size(0))
avgloss.update(loss, image.size(0))
if cnt >= ntrain_batches:
print('Loss', avgloss.avg)
print('Training: * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
.format(top1=top1, top5=top5))
return
print('Full imagenet train set: * Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f}'
.format(top1=top1, top5=top5))
return
์ด์ ์ฒ๋ผ ๋ชจ๋์ ๊ฒฐํฉํฉ๋๋ค.
qat_model = load_model(saved_model_dir + float_model_file)
qat_model.fuse_model(is_qat=True)
optimizer = torch.optim.SGD(qat_model.parameters(), lr = 0.0001)
# ์ด์ ์ 'fbgemm' ๋ํ ์ฌ์ ํ ์ฌ์ฉ ๊ฐ๋ฅํ์ง๋ง, 'x86'์ ๊ธฐ๋ณธ์ผ๋ก ์ฌ์ฉํ๋ ๊ฒ์ ๊ถ์ฅํฉ๋๋ค.
qat_model.qconfig = torch.ao.quantization.get_default_qat_qconfig('x86')
๋ง์ง๋ง์ผ๋ก ๋ชจ๋ธ์ด ์์ํ ์๊ฐ ํ์ต์ ์ค๋นํ๊ธฐ ์ํด prepare_qat
๋ก "๊ฐ์ง ์์ํ"๋ฅผ ์ํํฉ๋๋ค.
torch.ao.quantization.prepare_qat(qat_model, inplace=True)
print('Inverted Residual Block: After preparation for QAT, note fake-quantization modules \n',qat_model.features[1].conv)
๋์ ์ ํ๋์ ์์ํ๋ ๋ชจ๋ธ์ ํ์ต์ํค๊ธฐ ์ํด์๋ ์ถ๋ก ์์ ์์ ์ ํํ ์ซ์ ๋ชจ๋ธ๋ง์ ํ์๋ก ํฉ๋๋ค. ๊ทธ๋์ ์์ํ ์๊ฐ ํ์ต์์๋ ํ์ต ๋ฃจํ๋ฅผ ์ด๋ ๊ฒ ๋ณ๊ฒฝํฉ๋๋ค:
- ์ถ๋ก ์์น์ ๋ ์ ์ผ์นํ๋๋ก ํ์ต์ด ๋๋ ๋ ๋ฐฐ์น ์ ๊ทํ๋ฅผ ์ด๋ ํ๊ท ๊ณผ ๋ถ์ฐ์ ์ฌ์ฉํ๋ ๊ฒ์ผ๋ก ๋ณ๊ฒฝํฉ๋๋ค.
- ์์ํ ํ๋ผ๋ฏธํฐ(ํฌ๊ธฐ์ ์์ )๋ฅผ ๊ณ ์ ํ๊ณ ๊ฐ์ค์น๋ฅผ ๋ฏธ์ธ ์กฐ์ (fine tune)ํฉ๋๋ค.
num_train_batches = 20
# QAT๋ ์๊ฐ์ด ๊ฑธ๋ฆฌ๋ ์์
์ด๋ฉฐ ๋ช ์ํญ์ ๊ฑธ์ณ ํ๋ จ์ด ํ์ํฉ๋๋ค.
# ํ์ต ๋ฐ ๊ฐ ์ํญ ์ดํ ์ ํ๋ ํ์ธ
for nepoch in range(8):
train_one_epoch(qat_model, criterion, optimizer, data_loader, torch.device('cpu'), num_train_batches)
if nepoch > 3:
# ์์ํ ํ๋ผ๋ฏธํฐ ๊ณ ์
qat_model.apply(torch.ao.quantization.disable_observer)
if nepoch > 2:
# ๋ฐฐ์น ์ ๊ทํ ํ๊ท ๋ฐ ๋ถ์ฐ ์ถ์ ๊ฐ ๊ณ ์
qat_model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
# ๊ฐ ์ํญ ์ดํ ์ ํ๋ ํ์ธ
quantized_model = torch.ao.quantization.convert(qat_model.eval(), inplace=False)
quantized_model.eval()
top1, top5 = evaluate(quantized_model,criterion, data_loader_test, neval_batches=num_eval_batches)
print('Epoch %d :Evaluation accuracy on %d images, %2.2f'%(nepoch, num_eval_batches * eval_batch_size, top1.avg))
์์ํ ์๊ฐ ํ์ต์ ์ ์ฒด ImageNet ๋ฐ์ดํฐ์ ์์ 71.5%์ ์ ํ๋๋ฅผ ๋ํ๋ ๋๋ค. ์ด ๊ฐ์ ๊ธฐ์ค๊ฐ 71.9%์ ์์์ ์์ค์ผ๋ก ๊ทผ์ ํ ์์น์ ๋๋ค.
์์ํ ์๊ฐ ํ์ต์ ๋ํ ๋ ๋ง์ ๊ฒ๋ค:
- QAT๋ ๋ ๋ง์ ๋๋ฒ๊น ์ ๊ฐ๋ฅํ๊ฒ ํ๋ ํ์ต ํ ์์ํ ๊ธฐ์ ์ ์์ ์งํฉ์ ๋๋ค. ์๋ฅผ ๋ค์ด ๋ชจ๋ธ์ ์ ํ๋๊ฐ ๊ฐ์ค์น๋ ํ์ฑ ์์ํ๋ก ์ธํด ์ ํ์ ๋ฐ์ ๋ ๋์์ง ์ ์๋ ์ํฉ์ธ์ง ๋ถ์ํ ์ ์์ต๋๋ค.
- ๋ถ๋์์์ ์ ์ฌ์ฉํ ์์ํ๋ ๋ชจ๋ธ์ ์๋ฎฌ๋ ์ด์ ํ ์๋ ์์ต๋๋ค. ์ค์ ์์ํ๋ ์ฐ์ฐ์ ์์น๋ฅผ ๋ชจ๋ธ๋งํ๊ธฐ ์ํด ๊ฐ์ง ์์ํ๋ฅผ ์ด์ฉํ๊ณ ์๊ธฐ ๋๋ฌธ์ ๋๋ค.
- ํ์ต ํ ์์ํ ๋ํ ์ฝ๊ฒ ํ๋ด๋ผ ์ ์์ต๋๋ค.
๋ง์ง๋ง์ผ๋ก ์์์ ์ธ๊ธํ ๊ฒ๋ค์ ํ์ธํด ๋ด ์๋ค. ์์ํ๋ ๋ชจ๋ธ์ด ์ค์ ๋ก ์ถ๋ก ๋ ๋ ๋น ๋ฅด๊ฒ ํ๋ ๊ฑธ๊น์? ์ํํด ๋ด ์๋ค:
def run_benchmark(model_file, img_loader):
elapsed = 0
model = torch.jit.load(model_file)
model.eval()
num_batches = 5
# ์ด๋ฏธ์ง ๋ฐฐ์น๋ค ์ด์ฉํ์ฌ ์คํฌ๋ฆฝํธ๋ ๋ชจ๋ธ ์คํ
for i, (images, target) in enumerate(img_loader):
if i < num_batches:
start = time.time()
output = model(images)
end = time.time()
elapsed = elapsed + (end-start)
else:
break
num_images = images.size()[0] * num_batches
print('Elapsed time: %3.0f ms' % (elapsed/num_images*1000))
return elapsed
run_benchmark(saved_model_dir + scripted_float_model_file, data_loader_test)
run_benchmark(saved_model_dir + scripted_quantized_model_file, data_loader_test)
๋งฅ๋ถ ํ๋ก์ ๋ก์ปฌ ํ๊ฒฝ์์ ์ผ๋ฐ์ ์ธ ๋ชจ๋ธ ์คํ์ 61ms, ์์ํ๋ ๋ชจ๋ธ ์คํ์ 20ms๊ฐ ๊ฑธ๋ ธ์ต๋๋ค. ์ด๋ฌํ ๊ฒฐ๊ณผ๋ ๋ถ๋์์์ ๋ชจ๋ธ๊ณผ ์์ํ๋ ๋ชจ๋ธ์ ๋น๊ตํ์ ๋, ์์ํ๋ ๋ชจ๋ธ์์ ์ผ๋ฐ์ ์ผ๋ก 2-4x ์๋ ํฅ์์ด ์ด๋ฃจ์ด์ง ๊ฒ์ ๋ณด์ฌ์ค๋๋ค.
์ด ํํ ๋ฆฌ์ผ์์ ํ์ต ํ ์ ์ ์์ํ์ ์์ํ ์๊ฐ ํ์ต์ด๋ผ๋ ๋ ๊ฐ์ง ์์ํ ๋ฐฉ๋ฒ์ ์ดํด๋ดค์ต๋๋ค. ์ด ์์ํ ๋ฐฉ๋ฒ๋ค์ด "๋ด๋ถ์ ์ผ๋ก" ์ด๋ป๊ฒ ๋์์ ํ๋์ง์ PyTorch์์ ์ด๋ป๊ฒ ์ฌ์ฉํ ์ ์๋์ง๋ ๋ณด์์ต๋๋ค.
์ฝ์ด์ฃผ์ ์ ๊ฐ์ฌํฉ๋๋ค. ์ธ์ ๋์ฒ๋ผ ์ด๋ ํ ํผ๋๋ฐฑ๋ ํ์์ด๋, ์๊ฒฌ์ด ์๋ค๋ฉด ์ฌ๊ธฐ ์ ์ด์๋ฅผ ๋จ๊ฒจ ์ฃผ์ธ์.