In [1]:
import numpy as np
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
from torchvision import datasets
import torchvision.transforms as transforms
import os
import time
import sys
import torch.quantization

# # Setup warnings
import warnings
warnings.filterwarnings(
    action='ignore',
    category=DeprecationWarning,
    module=r'.*'
)
warnings.filterwarnings(
    action='default',
    module=r'torch.quantization'
)

# Specify random seed for repeatable results
torch.manual_seed(191009)

<torch._C.Generator at 0x12482f8f0>

In [2]:
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


def evaluate(model, criterion, data_loader, neval_batches):
    model.eval()
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    cnt = 0
    with torch.no_grad():
        for image, target in data_loader:
            output = model(image)
            loss = criterion(output, target)
            cnt += 1
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            print('.', end = '')
            top1.update(acc1[0], image.size(0))
            top5.update(acc5[0], image.size(0))
            if cnt >= neval_batches:
                 return top1, top5
    print('Evaluation accuracy on %d images, %2.2f'%(neval_batches * eval_batch_size, top1.avg))

    return top1, top5

def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch):
    model.train()
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    header = 'Epoch: [{}]'.format(epoch)
    for image, target in data_loader:
        start_time = time.time()
        image, target = image.to(device), target.to(device)
        output = model(image)
        loss = criterion(output, target)

        optimizer.zero_grad()
        
        loss.backward()
        optimizer.step()

        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        batch_size = image.shape[0]
        print('.', end = '')
        top1.update(acc1[0], batch_size)
        top5.update(acc5[0], batch_size)
    
    print('Training accuracy , %2.2f'%(top1.avg))

    return top1, top5



def print_size_of_model(model):
    torch.save(model.state_dict(), "temp.p")
    print('Size (MB):', os.path.getsize("temp.p")/1e6)
    os.remove('temp.p')

## 数据集

In [3]:
def prepare_data_loaders(data_path):

    traindir = os.path.join(data_path, 'train')
    valdir = os.path.join(data_path, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    dataset = torchvision.datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    dataset_test = torchvision.datasets.ImageFolder(
        valdir,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]))

    train_sampler = torch.utils.data.RandomSampler(dataset)
    test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=train_batch_size,
        sampler=train_sampler)

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=eval_batch_size,
        sampler=test_sampler)

    return data_loader, data_loader_test

In [4]:
data_path = '/Users/yizuotian/dataset/imagenet_1k'  # https://s3.amazonaws.com/pytorch-tutorial-assets/imagenet_1k.zip
saved_model_dir = '/Users/yizuotian/pretrained_model/'
float_model_file = 'mobilenet_v2-b0353104.pth' #'mobilenet_pretrained_float.pth'
scripted_float_model_file = 'mobilenet_quantization_scripted.pth'
scripted_quantized_model_file = 'mobilenet_quantization_scripted_quantized.pth'

train_batch_size = 30
eval_batch_size = 30

data_loader, data_loader_test = prepare_data_loaders(data_path)
criterion = nn.CrossEntropyLoss()


# shuffle net v2测试

In [5]:
def _replace_relu(module):
    reassign = {}
    for name, mod in module.named_children():
        _replace_relu(mod)
        # Checking for explicit type instead of instance
        # as we only want to replace modules of the exact type
        # not inherited classes
        if type(mod) == nn.ReLU or type(mod) == nn.ReLU6:
            reassign[name] = nn.ReLU(inplace=False)

    for key, value in reassign.items():
        module._modules[key] = value

In [6]:
# 'shufflenetv2_x1.0', False, False, True,
m = torchvision.models.quantization.QuantizableShuffleNetV2(
                         [4, 8, 4], [24, 116, 232, 464, 1024])
m.load_state_dict(torch.load('/Users/yizuotian/pretrained_model/shufflenetv2_x1-5666bf0f80.pth'))
m.to('cpu')
_replace_relu(m)
# Fuse Conv, bn and relu
m.fuse_model()

m.qconfig = torch.quantization.get_default_qat_qconfig('qnnpack')
torch.quantization.prepare_qat(m, inplace=True)

optimizer = torch.optim.SGD(
    m.parameters(), lr=1e-3, momentum=0.9,
    weight_decay=1e-4)

m.apply(torch.quantization.enable_observer)
m.apply(torch.quantization.enable_fake_quant)

QuantizableShuffleNetV2(
  (conv1): Sequential(
    (0): ConvBnReLU2d(
      3, 24, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False
      (activation_post_process): FakeQuantize(
        fake_quant_enabled=True, observer_enabled=True,            scale=None, zero_point=None
        (activation_post_process): MovingAverageMinMaxObserver(min_val=None, max_val=None)
      )
      (weight_fake_quant): FakeQuantize(
        fake_quant_enabled=True, observer_enabled=True,            scale=None, zero_point=None
        (activation_post_process): MovingAverageMinMaxObserver(min_val=None, max_val=None)
      )
    )
    (1): Identity()
    (2): Identity()
  )
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (stage2): Sequential(
    (0): QuantizableInvertedResidual(
      (branch1): Sequential(
        (0): ConvBn2d(
          24, 24, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=24, bias=False
          (activation_post_process)

## 训练

In [25]:
import copy
for epoch in range(5):
    train_one_epoch(m, criterion, optimizer, data_loader, 'cpu', epoch)
    
    with torch.no_grad():
        if epoch >= 4: #args.num_observer_update_epochs:
            print('Disabling observer for subseq epochs, epoch = ', epoch)
            m.apply(torch.quantization.disable_observer)
        if epoch >= 3:#args.num_batch_norm_update_epochs:
            print('Freezing BN for subseq epochs, epoch = ', epoch)
            m.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
        print('Evaluate QAT model')

        evaluate(m, criterion, data_loader_test, 1000)
        quantized_eval_model = copy.deepcopy(m)
        quantized_eval_model.eval()
        quantized_eval_model.to(torch.device('cpu'))
        torch.quantization.convert(quantized_eval_model, inplace=True)

        print('Evaluate Quantized model')
        evaluate(quantized_eval_model, criterion, data_loader_test,
                 1000)



..................................Training accuracy , 78.60
Evaluate QAT model
..................................Evaluation accuracy on 30000 images, 63.50
Evaluate Quantized model
..................................Evaluation accuracy on 30000 images, 59.90


In [None]:
## 保存

In [26]:
torch.jit.save(torch.jit.script(quantized_eval_model), saved_model_dir + 'shufflenetv2_x1_qnnpack_aware_trainning.pth')