(beta) Static Quantization with Eager Mode in PyTorch

์ €์ž: Raghuraman Krishnamoorthi ํŽธ์ง‘: Seth Weidman, Jerry Zhang ๋ฒˆ์—ญ: ๊น€ํ˜„๊ธธ, Choi Yoonjeong

This tutorial shows how to do post-training static quantization, and also illustrates two more advanced techniques for further improving the model's accuracy: per-channel quantization and quantization-aware training. Note that quantization is currently only supported on CPUs, so we will not be using GPUs / CUDA in this tutorial. By the end of this tutorial, you will see how quantization in PyTorch can substantially decrease model size while increasing speed. You will also see how easily some of the advanced quantization techniques shown here can be applied, so that your quantized models take much less of an accuracy hit than they otherwise would.

์ฃผ์˜: ๋‹ค๋ฅธ PyTorch ์ €์žฅ์†Œ์˜ ์ƒ์šฉ๊ตฌ ์ฝ”๋“œ(boilerplate code)๋ฅผ ๋งŽ์ด ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค. ์˜ˆ๋ฅผ ๋“ค์–ด MobileNetV2 ๋ชจ๋ธ ์•„ํ‚คํ…์ฒ˜ ์ •์˜, DataLoader ์ •์˜ ๊ฐ™์€ ๊ฒƒ๋“ค์ž…๋‹ˆ๋‹ค. ๋ฌผ๋ก  ์ด๋Ÿฐ ์ฝ”๋“œ๋“ค์„ ์ฝ๋Š” ๊ฒƒ์„ ์ถ”์ฒœํ•˜์ง€๋งŒ, ์–‘์žํ™” ํŠน์ง•๋งŒ ์•Œ๊ณ  ์‹ถ๋‹ค๋ฉด "4. ํ•™์Šต ํ›„ ์ •์  ์–‘์žํ™”" ๋ถ€๋ถ„์œผ๋กœ ๋„˜์–ด๊ฐ€๋„ ๋ฉ๋‹ˆ๋‹ค. ํ•„์š”ํ•œ ๊ฒƒ๋“ค์„ import ํ•˜๋Š” ๊ฒƒ๋ถ€ํ„ฐ ์‹œ์ž‘ํ•ด ๋ด…์‹œ๋‹ค:

import os
import sys
import time
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import torchvision
from torchvision import datasets
import torchvision.transforms as transforms

# Set up warnings
import warnings
warnings.filterwarnings(
    action='ignore',
    category=DeprecationWarning,
    module=r'.*'
)
warnings.filterwarnings(
    action='default',
    module=r'torch.ao.quantization'
)

# Specify random seed for repeatable results
torch.manual_seed(191009)

1. ๋ชจ๋ธ ์•„ํ‚คํ…์ฒ˜

์ฒ˜์Œ์œผ๋กœ MobileNetV2 ๋ชจ๋ธ ์•„ํ‚คํ…์ฒ˜๋ฅผ ์ •์˜ํ•ฉ๋‹ˆ๋‹ค. ์ด ๋ชจ๋ธ์€ ์–‘์žํ™”๋ฅผ ์œ„ํ•œ ๋ช‡ ๊ฐ€์ง€ ์ค‘์š”ํ•œ ๋ณ€๊ฒฝ์‚ฌํ•ญ๋“ค์ด ์žˆ์Šต๋‹ˆ๋‹ค:

  • Replacing addition with nn.quantized.FloatFunctional
  • Inserting QuantStub and DeQuantStub at the beginning and end of the network
  • Replacing ReLU6 with ReLU

Note: this code is taken from here.
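Before the full definition, here is a minimal sketch of what those three changes look like in isolation (the module name and shapes are illustrative only, not part of the tutorial's model):

import torch
import torch.nn as nn
from torch.ao.quantization import QuantStub, DeQuantStub

class TinyQuantizableBlock(nn.Module):
    # Hypothetical example module: shows the quantization-friendly patterns.
    def __init__(self):
        super().__init__()
        self.quant = QuantStub()      # marks where float input becomes quantized
        self.conv = nn.Conv2d(3, 3, 3, padding=1)
        self.relu = nn.ReLU()         # ReLU instead of ReLU6
        self.skip_add = nn.quantized.FloatFunctional()  # replaces bare '+'
        self.dequant = DeQuantStub()  # marks where output returns to float

    def forward(self, x):
        x = self.quant(x)
        # residual addition must go through FloatFunctional, not `x + y`
        x = self.skip_add.add(x, self.relu(self.conv(x)))
        return self.dequant(x)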

from torch.ao.quantization import QuantStub, DeQuantStub

def _make_divisible(v, divisor, min_value=None):
    """
    ์ด ํ•จ์ˆ˜๋Š” ์›๋ณธ TensorFlow ์ €์žฅ์†Œ์—์„œ ๊ฐ€์ ธ์™”์Šต๋‹ˆ๋‹ค.
    ๋ชจ๋“  ๊ณ„์ธต์ด 8๋กœ ๋‚˜๋ˆ„์–ด์ง€๋Š” ์ฑ„๋„ ์ˆซ์ž๋ฅผ ๊ฐ€์ง€๊ณ  ์žˆ์Šต๋‹ˆ๋‹ค.
    ์ด๊ณณ์—์„œ ํ™•์ธ ๊ฐ€๋Šฅํ•ฉ๋‹ˆ๋‹ค:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
    :param v:
    :param divisor:
    :param min_value:
    :return:
    """
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that rounding down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v


class ConvBNReLU(nn.Sequential):
    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
        padding = (kernel_size - 1) // 2
        super(ConvBNReLU, self).__init__(
            nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
            nn.BatchNorm2d(out_planes, momentum=0.1),
            # Replace with ReLU
            nn.ReLU(inplace=False)
        )


class InvertedResidual(nn.Module):
    def __init__(self, inp, oup, stride, expand_ratio):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2]

        hidden_dim = int(round(inp * expand_ratio))
        self.use_res_connect = self.stride == 1 and inp == oup

        layers = []
        if expand_ratio != 1:
            # pw
            layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
        layers.extend([
            # dw
            ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
            # pw-linear
            nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
            nn.BatchNorm2d(oup, momentum=0.1),
        ])
        self.conv = nn.Sequential(*layers)
        # Replace torch.add with FloatFunctional
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x):
        if self.use_res_connect:
            return self.skip_add.add(x, self.conv(x))
        else:
            return self.conv(x)


class MobileNetV2(nn.Module):
    def __init__(self, num_classes=1000, width_mult=1.0, inverted_residual_setting=None, round_nearest=8):
        """
        MobileNet V2 main class
        Args:
            num_classes (int): Number of classes
            width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
            inverted_residual_setting: Network structure
            round_nearest (int): Round the number of channels in each layer to be a multiple of this number
            Set to 1 to turn off rounding
        """
        super(MobileNetV2, self).__init__()
        block = InvertedResidual
        input_channel = 32
        last_channel = 1280

        if inverted_residual_setting is None:
            inverted_residual_setting = [
                # t, c, n, s
                [1, 16, 1, 1],
                [6, 24, 2, 2],
                [6, 32, 3, 2],
                [6, 64, 4, 2],
                [6, 96, 3, 1],
                [6, 160, 3, 2],
                [6, 320, 1, 1],
            ]

        # only check the first element, assuming user knows t,c,n,s are required
        if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
            raise ValueError("inverted_residual_setting should be non-empty "
                             "or a 4-element list, got {}".format(inverted_residual_setting))

        # building first layer
        input_channel = _make_divisible(input_channel * width_mult, round_nearest)
        self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
        features = [ConvBNReLU(3, input_channel, stride=2)]
        # building inverted residual blocks
        for t, c, n, s in inverted_residual_setting:
            output_channel = _make_divisible(c * width_mult, round_nearest)
            for i in range(n):
                stride = s if i == 0 else 1
                features.append(block(input_channel, output_channel, stride, expand_ratio=t))
                input_channel = output_channel
        # building last several layers
        features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1))
        # make it nn.Sequential
        self.features = nn.Sequential(*features)
        self.quant = QuantStub()
        self.dequant = DeQuantStub()
        # ๋ถ„๋ฅ˜๊ธฐ(classifier) ๋งŒ๋“ค๊ธฐ
        self.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(self.last_channel, num_classes),
        )

        # weight initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        x = self.quant(x)
        x = self.features(x)
        x = x.mean([2, 3])
        x = self.classifier(x)
        x = self.dequant(x)
        return x

    # ์–‘์žํ™” ์ „์— Conv+BN๊ณผ Conv+BN+Relu ๋ชจ๋“ˆ ๊ฒฐํ•ฉ(fusion)
    # ์ด ์—ฐ์‚ฐ์€ ์ˆซ์ž๋ฅผ ๋ณ€๊ฒฝํ•˜์ง€ ์•Š์Œ
    def fuse_model(self, is_qat=False):
        fuse_modules = torch.ao.quantization.fuse_modules_qat if is_qat else torch.ao.quantization.fuse_modules
        for m in self.modules():
            if type(m) == ConvBNReLU:
                fuse_modules(m, ['0', '1', '2'], inplace=True)
            if type(m) == InvertedResidual:
                for idx in range(len(m.conv)):
                    if type(m.conv[idx]) == nn.Conv2d:
                        fuse_modules(m.conv, [str(idx), str(idx + 1)], inplace=True)

2. ํ—ฌํผ(Helper) ํ•จ์ˆ˜

๋‹ค์Œ์œผ๋กœ ๋ชจ๋ธ ํ‰๊ฐ€๋ฅผ ์œ„ํ•œ ํ—ฌํผ ํ•จ์ˆ˜๋“ค์„ ๋งŒ๋“ญ๋‹ˆ๋‹ค. ์ฝ”๋“œ ๋Œ€๋ถ€๋ถ„์€ ์—ฌ๊ธฐ ์—์„œ ๊ฐ€์ ธ์™”์Šต๋‹ˆ๋‹ค.

class AverageMeter(object):
    """ํ‰๊ท ๊ณผ ํ˜„์žฌ ๊ฐ’ ๊ณ„์‚ฐ ๋ฐ ์ €์žฅ"""
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)


def accuracy(output, target, topk=(1,)):
    """ํŠน์ • k๊ฐ’์„ ์œ„ํ•ด top k ์˜ˆ์ธก์˜ ์ •ํ™•๋„ ๊ณ„์‚ฐ"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


def evaluate(model, criterion, data_loader, neval_batches):
    model.eval()
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    cnt = 0
    with torch.no_grad():
        for image, target in data_loader:
            output = model(image)
            loss = criterion(output, target)
            cnt += 1
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            print('.', end = '')
            top1.update(acc1[0], image.size(0))
            top5.update(acc5[0], image.size(0))
            if cnt >= neval_batches:
                return top1, top5

    return top1, top5

def load_model(model_file):
    model = MobileNetV2()
    state_dict = torch.load(model_file)
    model.load_state_dict(state_dict)
    model.to('cpu')
    return model

def print_size_of_model(model):
    torch.save(model.state_dict(), "temp.p")
    print('Size (MB):', os.path.getsize("temp.p")/1e6)
    os.remove('temp.p')

3. Define dataset and data loaders

As our last major setup step, we define the DataLoaders for our training and testing sets.

ImageNet Data

To run the code in this tutorial using the entire ImageNet dataset, first download ImageNet by following the instructions at ImageNet Data. Unzip the downloaded file into the 'data_path' folder.

We use the DataLoader functions defined below to read the downloaded data. These functions mostly come from here.

def prepare_data_loaders(data_path):

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    dataset = torchvision.datasets.ImageNet(
        data_path, split="train", transform=transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    dataset_test = torchvision.datasets.ImageNet(
        data_path, split="val", transform=transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]))

    train_sampler = torch.utils.data.RandomSampler(dataset)
    test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=train_batch_size,
        sampler=train_sampler)

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=eval_batch_size,
        sampler=test_sampler)

    return data_loader, data_loader_test

๋‹ค์Œ์œผ๋กœ ์‚ฌ์ „์— ํ•™์Šต๋œ MobileNetV2์„ ๋ถˆ๋Ÿฌ์˜ต๋‹ˆ๋‹ค. ๋ชจ๋ธ์„ ๋‹ค์šด๋กœ๋“œ ๋ฐ›์„ ์ˆ˜ ์žˆ๋Š” URL์„ `์—ฌ๊ธฐ <<https://download.pytorch.org/models/mobilenet_v2-b0353104.pth>>`_ ์—์„œ ์ œ๊ณตํ•ฉ๋‹ˆ๋‹ค.

data_path = '~/.data/imagenet'
saved_model_dir = 'data/'
float_model_file = 'mobilenet_pretrained_float.pth'
scripted_float_model_file = 'mobilenet_quantization_scripted.pth'
scripted_quantized_model_file = 'mobilenet_quantization_scripted_quantized.pth'

train_batch_size = 30
eval_batch_size = 50

data_loader, data_loader_test = prepare_data_loaders(data_path)
criterion = nn.CrossEntropyLoss()
float_model = load_model(saved_model_dir + float_model_file).to('cpu')

# ๋‹ค์Œ์œผ๋กœ "๋ชจ๋“ˆ ๊ฒฐํ•ฉ"์„ ํ•ฉ๋‹ˆ๋‹ค. ๋ชจ๋“ˆ ๊ฒฐํ•ฉ์€ ๋ฉ”๋ชจ๋ฆฌ ์ ‘๊ทผ์„ ์ค„์—ฌ ๋ชจ๋ธ์„ ๋น ๋ฅด๊ฒŒ ๋งŒ๋“ค๋ฉด์„œ
# ์ •ํ™•๋„ ์ˆ˜์น˜๋ฅผ ํ–ฅ์ƒ์‹œํ‚ต๋‹ˆ๋‹ค. ๋ชจ๋“ˆ ๊ฒฐํ•ฉ์€ ์–ด๋– ํ•œ ๋ชจ๋ธ์—๋ผ๋„ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ์ง€๋งŒ,
# ์–‘์žํ™”๋œ ๋ชจ๋ธ์— ์‚ฌ์šฉํ•˜๋Š” ๊ฒƒ์ด ํŠนํžˆ๋‚˜ ๋” ์ผ๋ฐ˜์ ์ž…๋‹ˆ๋‹ค.

print('\n Inverted Residual Block: Before fusion \n\n', float_model.features[1].conv)
float_model.eval()

# Fuse modules
float_model.fuse_model()

# Note fusion of Conv+BN+Relu and Conv+Relu
print('\n Inverted Residual Block: After fusion\n\n',float_model.features[1].conv)

Finally, to get a "baseline" accuracy, let's look at the accuracy of the un-quantized model with fused modules.

num_eval_batches = 1000

print("Size of baseline model")
print_size_of_model(float_model)

top1, top5 = evaluate(float_model, criterion, data_loader_test, neval_batches=num_eval_batches)
print('Evaluation accuracy on %d images, %2.2f'%(num_eval_batches * eval_batch_size, top1.avg))
torch.jit.save(torch.jit.script(float_model), saved_model_dir + scripted_float_model_file)

On the entire model, we get an accuracy of 71.9% on the eval dataset of 50,000 images.

This will be our baseline to compare against. Next, let's look at quantized models.

4. Post-training static quantization

Post-training static quantization involves not just converting the weights from float to int, as in dynamic quantization, but also performing an additional step: first feeding batches of data through the network and computing the resulting distributions of the different activations. (Specifically, this is done by inserting observer modules at the points where we want to record these values.) These distributions are then used to determine how the different activations should be quantized at inference time. (A simple technique would be to divide the entire range of activations into 256 levels, but more sophisticated methods are supported as well.) Importantly, this additional step allows us to pass quantized values between operations instead of converting those values to floats - and then back to ints - between every operation, resulting in a significant speed-up.
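To make the "divide the range into 256 levels" idea concrete, here is a small sketch of affine quantization driven by a tensor's min/max - essentially what the default MinMaxObserver records (illustrative only, not PyTorch's actual implementation):

import torch

x = torch.randn(4, 4)

# Map [min_val, max_val] linearly onto the 256 uint8 levels.
min_val, max_val = x.min().item(), x.max().item()
scale = (max_val - min_val) / 255.0
zero_point = int(round(-min_val / scale))

x_q = (x / scale + zero_point).round().clamp(0, 255).to(torch.uint8)  # quantize
x_dq = (x_q.float() - zero_point) * scale                             # dequantize
print((x - x_dq).abs().max())  # quantization error, bounded by about scale / 2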

num_calibration_batches = 32

myModel = load_model(saved_model_dir + float_model_file).to('cpu')
myModel.eval()

# Fuse Conv, bn and relu
myModel.fuse_model()

# ์–‘์žํ™” ์„ค์ • ๋ช…์‹œ
# ๊ฐ„๋‹จํ•œ min/max ๋ฒ”์œ„ ์ถ”์ • ๋ฐ ํ…์„œ๋ณ„ ๊ฐ€์ค‘์น˜ ์–‘์žํ™”๋กœ ์‹œ์ž‘
myModel.qconfig = torch.ao.quantization.default_qconfig
print(myModel.qconfig)
torch.ao.quantization.prepare(myModel, inplace=True)

# Calibrate first
print('Post Training Quantization Prepare: Inserting Observers')
print('\n Inverted Residual Block:After observer insertion \n\n', myModel.features[1].conv)

# Calibrate with the training set
evaluate(myModel, criterion, data_loader, neval_batches=num_calibration_batches)
print('Post Training Quantization: Calibration done')

# ์–‘์žํ™”๋œ ๋ชจ๋ธ๋กœ ๋ณ€ํ™˜
torch.ao.quantization.convert(myModel, inplace=True)
# ๋ชจ๋ธ์„ ๋ณด์ •ํ•ด์•ผ ํ•œ๋‹ค(calibrate the model)๋Š” ์‚ฌ์šฉ์ž ๊ฒฝ๊ณ (user warning)๊ฐ€ ํ‘œ์‹œ๋  ์ˆ˜ ์žˆ์ง€๋งŒ ๋ฌด์‹œํ•ด๋„ ๋ฉ๋‹ˆ๋‹ค.
# ์ด ๊ฒฝ๊ณ ๋Š” ๊ฐ ๋ชจ๋ธ ์‹คํ–‰ ์‹œ ๋ชจ๋“  ๋ชจ๋“ˆ์ด ์‹คํ–‰๋˜๋Š” ๊ฒƒ์ด ์•„๋‹ˆ๊ธฐ ๋•Œ๋ฌธ์— ์ผ๋ถ€ ๋ชจ๋“ˆ์ด ๋ณด์ •๋˜์ง€ ์•Š์„ ์ˆ˜
# ์žˆ๋‹ค๋Š” ๊ฒฝ๊ณ ์ž…๋‹ˆ๋‹ค.
print('Post Training Quantization: Convert done')
print('\n Inverted Residual Block: After fusion and quantization, note fused modules: \n\n',myModel.features[1].conv)

print("Size of model after quantization")
print_size_of_model(myModel)

top1, top5 = evaluate(myModel, criterion, data_loader_test, neval_batches=num_eval_batches)
print('Evaluation accuracy on %d images, %2.2f'%(num_eval_batches * eval_batch_size, top1.avg))

์–‘์žํ™”๋œ ๋ชจ๋ธ์€ eval ๋ฐ์ดํ„ฐ์…‹์—์„œ 56.7%์˜ ์ •ํ™•๋„๋ฅผ ๋ณด์—ฌ์ค๋‹ˆ๋‹ค. ์ด๋Š” ์–‘์žํ™” ํŒŒ๋ผ๋ฏธํ„ฐ๋ฅผ ๊ฒฐ์ •ํ•˜๊ธฐ ์œ„ํ•ด ๋‹จ์ˆœ min/max Observer๋ฅผ ์‚ฌ์šฉํ–ˆ๊ธฐ ๋•Œ๋ฌธ์ž…๋‹ˆ๋‹ค. ๊ทธ๋Ÿผ์—๋„ ๋ถˆ๊ตฌํ•˜๊ณ  ๋ชจ๋ธ์˜ ํฌ๊ธฐ๋ฅผ 3.6 MB ๋ฐ‘์œผ๋กœ ์ค„์˜€์Šต๋‹ˆ๋‹ค. ์ด๋Š” ๊ฑฐ์˜ 4๋ถ„์˜ 1 ๋กœ ์ค„์–ด๋“  ํฌ๊ธฐ์ž…๋‹ˆ๋‹ค.

In addition, we can significantly improve on the accuracy simply by using a different quantization configuration. We just use the recommended configuration for quantizing for x86 architectures. This configuration does the following:

  • Quantizes weights on a per-channel basis
  • Uses a histogram observer that collects a histogram of activations and then picks quantization parameters in an optimal manner
per_channel_quantized_model = load_model(saved_model_dir + float_model_file)
per_channel_quantized_model.eval()
per_channel_quantized_model.fuse_model()
# The old 'fbgemm' is still available, but 'x86' is the recommended default.
per_channel_quantized_model.qconfig = torch.ao.quantization.get_default_qconfig('x86')
print(per_channel_quantized_model.qconfig)

torch.ao.quantization.prepare(per_channel_quantized_model, inplace=True)
evaluate(per_channel_quantized_model,criterion, data_loader, num_calibration_batches)
torch.ao.quantization.convert(per_channel_quantized_model, inplace=True)
top1, top5 = evaluate(per_channel_quantized_model, criterion, data_loader_test, neval_batches=num_eval_batches)
print('Evaluation accuracy on %d images, %2.2f'%(num_eval_batches * eval_batch_size, top1.avg))
torch.jit.save(torch.jit.script(per_channel_quantized_model), saved_model_dir + scripted_quantized_model_file)
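For reference, the 'x86' default roughly corresponds to a QConfig built from a histogram observer for activations and a per-channel min/max observer for weights. A hand-assembled sketch (the exact defaults may differ across PyTorch versions):

from torch.ao.quantization import QConfig
from torch.ao.quantization.observer import HistogramObserver, PerChannelMinMaxObserver

# Approximation of get_default_qconfig('x86'); treat as illustrative.
manual_qconfig = QConfig(
    activation=HistogramObserver.with_args(reduce_range=True),
    weight=PerChannelMinMaxObserver.with_args(dtype=torch.qint8,
                                              qscheme=torch.per_channel_symmetric))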

Changing just this quantization configuration increased the accuracy to over 67.3%! Still, this is 4% worse than the baseline of 71.9% achieved above. So let's try quantization-aware training.

5. Quantization-aware training

Quantization-aware training (QAT) is the quantization method that typically results in the highest accuracy. With QAT, all weights and activations are "fake quantized" during both the forward and backward passes of training: that is, float values are rounded to mimic int8 values, but all computations are still done with floating point numbers. Thus, all the weight adjustments during training are made while "aware" of the fact that the model will ultimately be quantized; after quantizing, therefore, QAT usually yields higher accuracy than either dynamic quantization or post-training static quantization.
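The rounding that "fake quantization" performs can be sketched in a few lines (illustrative; PyTorch's FakeQuantize modules additionally pass gradients through the rounding with a straight-through estimator):

def fake_quantize(x, scale, zero_point, qmin=0, qmax=255):
    # Snap each value to the nearest representable quantized level...
    q = (x / scale + zero_point).round().clamp(qmin, qmax)
    # ...then return immediately to float, so training math stays in float32.
    return (q - zero_point) * scale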

The overall workflow for actually performing QAT is very similar to before:

  • We can use the same model as before: there is no additional preparation needed for quantization-aware training.
  • We need to use a qconfig specifying what kind of fake-quantization is to be inserted after weights and activations, instead of specifying observers; the two configs are compared in the sketch below.
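The difference between the two configs is easy to see by printing them; the QAT config wraps its observers in FakeQuantize modules (output details vary by PyTorch version):

print(torch.ao.quantization.get_default_qconfig('x86'))      # observers only
print(torch.ao.quantization.get_default_qat_qconfig('x86'))  # FakeQuantize wrappers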

We first define a training function:

def train_one_epoch(model, criterion, optimizer, data_loader, device, ntrain_batches):
    model.train()
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    avgloss = AverageMeter('Loss', ':1.5f')

    cnt = 0
    for image, target in data_loader:
        start_time = time.time()
        print('.', end = '')
        cnt += 1
        image, target = image.to(device), target.to(device)
        output = model(image)
        loss = criterion(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        top1.update(acc1[0], image.size(0))
        top5.update(acc5[0], image.size(0))
        avgloss.update(loss.item(), image.size(0))
        if cnt >= ntrain_batches:
            print('Loss', avgloss.avg)

            print('Training: * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
                  .format(top1=top1, top5=top5))
            return

    print('Full imagenet train set:  * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
          .format(top1=top1, top5=top5))
    return

We fuse modules as before:

qat_model = load_model(saved_model_dir + float_model_file)
qat_model.fuse_model(is_qat=True)

optimizer = torch.optim.SGD(qat_model.parameters(), lr = 0.0001)
# The old 'fbgemm' is still available, but 'x86' is the recommended default.
qat_model.qconfig = torch.ao.quantization.get_default_qat_qconfig('x86')

Finally, prepare_qat performs the "fake quantization", preparing the model for quantization-aware training:

torch.ao.quantization.prepare_qat(qat_model, inplace=True)
print('Inverted Residual Block: After preparation for QAT, note fake-quantization modules \n',qat_model.features[1].conv)

๋†’์€ ์ •ํ™•๋„์˜ ์–‘์žํ™”๋œ ๋ชจ๋ธ์„ ํ•™์Šต์‹œํ‚ค๊ธฐ ์œ„ํ•ด์„œ๋Š” ์ถ”๋ก  ์‹œ์ ์—์„œ ์ •ํ™•ํ•œ ์ˆซ์ž ๋ชจ๋ธ๋ง์„ ํ•„์š”๋กœ ํ•ฉ๋‹ˆ๋‹ค. ๊ทธ๋ž˜์„œ ์–‘์žํ™” ์ž๊ฐ ํ•™์Šต์—์„œ๋Š” ํ•™์Šต ๋ฃจํ”„๋ฅผ ์ด๋ ‡๊ฒŒ ๋ณ€๊ฒฝํ•ฉ๋‹ˆ๋‹ค:

  • Switch batch norm to use running mean and variance towards the end of training, to better match inference numerics.
  • Freeze the quantizer parameters (scale and zero-point) and fine-tune the weights.
num_train_batches = 20

# QAT takes time and one needs to train over a few epochs.
# Train and check accuracy after each epoch
for nepoch in range(8):
    train_one_epoch(qat_model, criterion, optimizer, data_loader, torch.device('cpu'), num_train_batches)
    if nepoch > 3:
        # Freeze quantizer parameters
        qat_model.apply(torch.ao.quantization.disable_observer)
    if nepoch > 2:
        # Freeze batch norm mean and variance estimates
        qat_model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)

    # ๊ฐ ์—ํญ ์ดํ›„ ์ •ํ™•๋„ ํ™•์ธ
    quantized_model = torch.ao.quantization.convert(qat_model.eval(), inplace=False)
    quantized_model.eval()
    top1, top5 = evaluate(quantized_model,criterion, data_loader_test, neval_batches=num_eval_batches)
    print('Epoch %d :Evaluation accuracy on %d images, %2.2f'%(nepoch, num_eval_batches * eval_batch_size, top1.avg))

Quantization-aware training yields an accuracy of 71.5% on the entire ImageNet dataset, which is within a fraction of a percent of the floating point baseline of 71.9%.

More on quantization-aware training:

  • QAT is a super-set of the post-training quantization techniques that allows for more debugging. For example, we can analyze whether the accuracy of the model is limited by weight or activation quantization.
  • We can also simulate the accuracy of a quantized model in floating point, since we are using fake-quantization to model the numerics of actual quantized arithmetic.
  • We can mimic post-training quantization easily, too.

Speedup from quantization

Finally, let's confirm something we alluded to above: do our quantized models actually perform inference faster? Let's test:

def run_benchmark(model_file, img_loader):
    elapsed = 0
    model = torch.jit.load(model_file)
    model.eval()
    num_batches = 5
    # Run the scripted model on sample image batches
    for i, (images, target) in enumerate(img_loader):
        if i < num_batches:
            start = time.time()
            output = model(images)
            end = time.time()
            elapsed = elapsed + (end-start)
        else:
            break
    num_images = images.size()[0] * num_batches

    print('Elapsed time: %3.0f ms' % (elapsed/num_images*1000))
    return elapsed

run_benchmark(saved_model_dir + scripted_float_model_file, data_loader_test)

run_benchmark(saved_model_dir + scripted_quantized_model_file, data_loader_test)

Running this locally on a MacBook Pro yielded 61 ms for the regular model and just 20 ms for the quantized model, illustrating the typical 2-4x speedup we see for quantized models compared to floating point ones.

Conclusion

In this tutorial, we showed two quantization methods - post-training static quantization and quantization-aware training - describing what they do "under the hood" and how to use them in PyTorch.

Thanks for reading! As always, we welcome any feedback, so please create an issue here if you have any.