In [14]:
%load_ext autoreload

%autoreload 2
import sys
sys.path.append("..")
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import myobserver
import myfake_quantize
import mynet
from pathlib import Path
import numpy as np
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision import transforms
from PIL import Image
import re
from sklearn.metrics import accuracy_score, confusion_matrix
import torch.quantization as tq
import os

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [15]:
def init_parameters(layer):
    if type(layer) == (nn.Linear or nn.Conv2d):
        nn.init.xavier_uniform_(layer.weight) # 重みを「一様分布のランダム値」に初期化
        layer.bias.data.fill_(0.0)            # バイアスを「0」に初期化

def train(model, device, train_loader, loss_func, optimizer):
  total_acc = 0
  total_loss = 0
  model.train()
  for batch_idx, (data, target) in enumerate(train_loader):
    output = model(data)
    optimizer.zero_grad()
    loss = loss_func(output, target)
    total_loss += loss.item()
    loss.backward()
    optimizer.step()
    with torch.no_grad(): 
      pred = torch.argmax(output, dim = 1)
      total_acc += pred.eq(target.view_as(pred)).sum().item()
  avg_acc = total_acc / len(train_loader.dataset)
  avg_loss = total_loss / len(train_loader.dataset)
  return avg_loss, avg_acc

def test(model, device, test_loader, loss_func):
    model.eval()
    total_loss = 0
    total_acc = 0
    ans_list = []
    pred_list = []
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            total_loss += loss_func(output, target).item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            total_acc += pred.eq(target.view_as(pred)).sum().item()
            ans_list += target.tolist()
            pred_list += pred.tolist()
    avg_acc = total_acc / len(test_loader.dataset)
    avg_loss = total_loss / len(test_loader.dataset)
    return avg_loss, avg_acc

In [46]:
from enum import Enum,Flag, auto
class FixedMode(Flag):
    PerTen = auto()
    PerCh = auto()
    Pow2 = auto()
    Float = auto()
    Affine = auto()
    Symmetric = auto()
    PerTenPow2 = PerTen | Pow2 | Symmetric
    PerChPow2 = PerCh | Pow2 | Symmetric
    PerTenFloat = PerTen | Float | Symmetric
    PerChFloat = PerCh | Float | Symmetric
    PerTenPow2Affine = PerTen | Pow2 | Affine
    PerChPow2Affine = PerCh | Pow2 | Affine
    PerTenFloatAffine = PerTen | Float | Affine
    PerChFloatAffine = PerCh | Float | Affine
model_qat = mynet.QuantNet()
model_qat.eval()
act_bit = 3
weight_bit = 3
weight_qmin = -(2**(weight_bit - 1))
weight_qmax = (2**(weight_bit - 1)-1)
act_qmin = 0
act_qmax = 2**act_bit - 1
fixedmode = FixedMode.PerChPow2Affine
act_qscheme = torch.per_channel_affine


# weight_qscheme
if (fixedmode & FixedMode.PerCh) and (fixedmode & FixedMode.Affine):
    weight_qscheme=torch.per_channel_affine
    print("PerCh Affine")
elif (fixedmode & FixedMode.PerCh) and (fixedmode & FixedMode.Symmetric):
    weight_qscheme=torch.per_channel_symmetric
    print("PerCh Symmetric")
elif (fixedmode & FixedMode.PerTen) and (fixedmode & FixedMode.Affine):
    weight_qscheme=torch.per_tensor_affine
    print("PerTen Affine")
elif (fixedmode & FixedMode.PerTen) and (fixedmode & FixedMode.Symmetric):
    weight_qscheme=torch.per_tensor_symmetric
    print("PerTen Symmetric")

fake_quantize = myfake_quantize.ApFusedMovingAvgObsFakeQuantize
if (fixedmode & FixedMode.PerCh) and (fixedmode & FixedMode.Float):
    act_observer = tq.MovingAverageMinMaxObserver
    weight_observer = tq.MovingAveragePerChannelMinMaxObserver
    
elif (fixedmode & FixedMode.PerTen) and (fixedmode & FixedMode.Float):
    act_observer = tq.MovingAverageMinMaxObserver
    weight_observer = tq.MovingAverageMinMaxObserver
    
elif (fixedmode & FixedMode.PerCh) and (fixedmode & FixedMode.Pow2):
    act_observer = myobserver.Pow2MovingAverageMinMaxObserver
    weight_observer = myobserver.Pow2MovingAveragePerChannelMinMaxObserver
    
elif (fixedmode & FixedMode.PerTen) and (fixedmode & FixedMode.Pow2):
    act_observer = myobserver.Pow2MovingAverageMinMaxObserver
    weight_observer = myobserver.Pow2MovingAverageMinMaxObserver
model_qat.qconfig = torch.quantization.QConfig(activation=fake_quantize.with_args(
                                            observer=act_observer.with_args(quant_min=0,quant_max=act_qmax),dtype=torch.quint8,quant_min=0,quant_max=act_qmax),
                      weight=fake_quantize.with_args(observer = weight_observer.with_args(quant_min=weight_qmin,quant_max=weight_qmax,qscheme = weight_qscheme)
                                                     ,dtype=torch.qint8,quant_min=weight_qmin,quant_max=weight_qmax))
model_qat.fc2.qconfig = torch.quantization.QConfig(activation=fake_quantize.with_args(
                                            observer=act_observer.with_args(quant_min=0,quant_max=255),dtype=torch.quint8,quant_min=0,quant_max=255),
                      weight=fake_quantize.with_args(observer = weight_observer.with_args(quant_min=weight_qmin,quant_max=weight_qmax,qscheme = weight_qscheme),dtype=torch.qint8,quant_min=weight_qmin,quant_max=weight_qmax))


torch.quantization.fuse_modules(model_qat, [['conv1', 'relu1']],inplace=True)
torch.quantization.fuse_modules(model_qat, [['conv2', 'relu2']],inplace=True)
torch.quantization.fuse_modules(model_qat, [['fc1', 'relu3']],inplace=True)
torch.quantization.prepare_qat(model_qat,inplace=True)
# model_qat.conv1.activation_post_process.activation_post_process



PerCh Affine


QuantNet(
  (quant): QuantStub(
    (activation_post_process): ApFusedMovingAvgObsFakeQuantize(
      fake_quant_enabled=tensor([1]), observer_enabled=tensor([1]), scale=tensor([1.]), zero_point=tensor([0], dtype=torch.int32), dtype=torch.quint8, quant_min=0, quant_max=7, qscheme=torch.per_tensor_affine, reduce_range=False
      (activation_post_process): Pow2MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
    )
  )
  (conv1): ConvReLU2d(
    1, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
    (weight_fake_quant): ApFusedMovingAvgObsFakeQuantize(
      fake_quant_enabled=tensor([1]), observer_enabled=tensor([1]), scale=tensor([1.]), zero_point=tensor([0], dtype=torch.int32), dtype=torch.qint8, quant_min=-4, quant_max=3, qscheme=torch.per_channel_affine, reduce_range=False
      (activation_post_process): Pow2MovingAveragePerChannelMinMaxObserver(min_val=tensor([]), max_val=tensor([]))
    )
    (activation_post_process): ApFusedMovingAvgObsFakeQuantize(
      fake_quant_

In [47]:
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())
testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())

# データの読み出し方法の定義
# 1stepの学習・テストごとに16枚ずつ画像を読みだす
trainloader = torch.utils.data.DataLoader(trainset, batch_size=16, shuffle=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=16, shuffle=False)

In [48]:
epochs = 10
loss_func = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model_qat.parameters(), lr=0.0001)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# QAT takes time and one needs to train over a few epochs.
# Train and check accuracy after each epoch
for epoch in range(1,epochs+1):
    avg_loss, avg_acc = train(model_qat,device,trainloader,loss_func,optimizer)
    if epoch > 3:
        # Freeze quantizer parameters
        model_qat.apply(torch.quantization.disable_observer)
    if epoch > 2:
        # Freeze batch norm mean and variance estimates
        model_qat.apply(torch.nn.intrinsic.qat.freeze_bn_stats)

    # Check the accuracy after each epoch
    quantized_model = torch.quantization.convert(model_qat.eval(), inplace=False)
    quantized_model.eval()
    avg_val_loss, avg_val_acc = test(quantized_model,device,testloader,loss_func)
    print(f'[Epoch {epoch:3d}/{epochs:3d}]' \
          f' loss: {avg_loss:.5f}, acc: {avg_acc:.5f}' \
          f' val_loss: {avg_val_loss:.5f}, val_acc: {avg_val_acc:.5f}')

[Epoch   1/ 10] loss: 0.05358, acc: 0.76465 val_loss: 0.03317, val_acc: 0.84120
[Epoch   2/ 10] loss: 0.02153, acc: 0.89123 val_loss: 0.02689, val_acc: 0.85640
[Epoch   3/ 10] loss: 0.01705, acc: 0.91312 val_loss: 0.02582, val_acc: 0.86910
[Epoch   4/ 10] loss: 0.01465, acc: 0.92585 val_loss: 0.02738, val_acc: 0.86110
[Epoch   5/ 10] loss: 0.01263, acc: 0.93750 val_loss: 0.01806, val_acc: 0.90770
[Epoch   6/ 10] loss: 0.01087, acc: 0.94463 val_loss: 0.01576, val_acc: 0.91760
[Epoch   7/ 10] loss: 0.00992, acc: 0.95023 val_loss: 0.02234, val_acc: 0.88360
[Epoch   8/ 10] loss: 0.00925, acc: 0.95423 val_loss: 0.01631, val_acc: 0.91700
[Epoch   9/ 10] loss: 0.00864, acc: 0.95710 val_loss: 0.02429, val_acc: 0.87430
[Epoch  10/ 10] loss: 0.00833, acc: 0.95905 val_loss: 0.02425, val_acc: 0.87800


In [49]:
model_int8 = torch.quantization.convert(model_qat)

In [61]:
model_int8.conv2.weight().q_per_channel_scales()

tensor([0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.2500],
       dtype=torch.float64)