# Post-Training Quantization (PTQ): basic experiments





In [None]:
# Basic imports
import time
from enum import Enum

import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision
import torchvision.transforms as transforms

import numpy as np
import matplotlib.pyplot as plt

## Arguments

### argument options
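*   bitW : number of bits used to quantize weights (32 disables weight quantization)
*   bitA : number of bits used to quantize activation maps (32 disables activation quantization)
*   per_ch : weight quantization granularity — per-channel when True, per-layer (per-tensor) when False; currently honored only for symmetric weight quantization (symW=True)
*   symW : use symmetric (True) or asymmetric (False) quantization for weights
*   symA : intended to select symmetric or asymmetric quantization for activation maps; activations are currently always quantized asymmetrically, so this flag has no effect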


In [None]:
import easydict
# Experiment configuration (key order is preserved when printed).
args = easydict.EasyDict(dict(
    workers=4,        # DataLoader worker processes
    batch_size=128,
    print_freq=10,    # print a progress line every N validation batches
    bitW=4,           # weight bits: 1~k bit(2 ~ 8 bit recommended); 32 disables quantization
    bitA=4,           # activation bits: 1~k bit(2 ~ 8 bit recommended); 32 disables quantization
    symW=True,        # symmetric (True) / asymmetric (False) weight quantization
    per_ch=False,     # per-channel (True) / per-layer (False) weight scales
    symA=False,       # activation symmetry flag (quantize_activation does not read it)
))

# Pick the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('Training on GPU' if device.type == "cuda" else 'Training on CPU')


Training on GPU


## Functions


*   Helpers for measuring accuracy and reporting progress
*   Prepare the CIFAR-10 datasets and data loaders


In [None]:
# Loss passed to validate(): cross-entropy over the model's class logits.
criterion = nn.CrossEntropyLoss()
criterion = criterion.to(device)  # Module.to moves in place and returns the same module

In [None]:
# Which statistic AverageMeter.summary() reports for a meter
# (NONE prints nothing, AVERAGE/SUM/COUNT print that running value).
Summary = Enum('Summary', [('NONE', 0), ('AVERAGE', 1), ('SUM', 2), ('COUNT', 3)])

class AverageMeter(object):
    """Computes and stores the latest value and the running average of a metric.

    Args:
        name: label used by ``__str__`` and ``summary``.
        fmt: format spec applied to ``val`` and ``avg`` in ``__str__`` (e.g. ``':6.2f'``).
        summary_type: statistic reported by ``summary`` (a ``Summary`` member).
    """

    def __init__(self, name, fmt=':f', summary_type=Summary.AVERAGE):
        self.name = name
        self.fmt = fmt
        self.summary_type = summary_type
        self.reset()

    def reset(self):
        """Zero the latest value and all running statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the mean of ``n`` new samples."""
        self.val = val
        self.sum += val * n
        self.count += n
        # update(..., n=0) on an empty meter would otherwise divide by zero.
        self.avg = self.sum / self.count if self.count else 0

    def all_reduce(self):
        """Sum ``sum`` and ``count`` over all processes, then recompute ``avg``.

        The cross-process reduction runs only when ``torch.distributed`` is
        available and initialised; otherwise the values are just converted to
        Python floats.
        """
        import torch.distributed as dist

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # float() also accepts 0-d tensors (accuracies are accumulated as tensors).
        total = torch.tensor([float(self.sum), float(self.count)], dtype=torch.float32, device=device)
        if dist.is_available() and dist.is_initialized():
            dist.all_reduce(total, dist.ReduceOp.SUM)
        self.sum, self.count = total.tolist()
        self.avg = self.sum / self.count if self.count else 0

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)

    def summary(self):
        """Return a one-line summary of the statistic selected by ``summary_type``."""
        fmtstr = ''
        if self.summary_type is Summary.NONE:
            fmtstr = ''
        elif self.summary_type is Summary.AVERAGE:
            fmtstr = '{name} {avg:.3f}'
        elif self.summary_type is Summary.SUM:
            fmtstr = '{name} {sum:.3f}'
        elif self.summary_type is Summary.COUNT:
            fmtstr = '{name} {count:.3f}'
        else:
            raise ValueError('invalid summary type %r' % self.summary_type)

        return fmtstr.format(**self.__dict__)


class ProgressMeter(object):
    """Prints per-batch progress lines and a final summary line for a set of meters.

    Args:
        num_batches: total number of batches, shown as the ``[ i/N]`` counter.
        meters: objects rendered with ``str()`` by ``display`` and with
            ``.summary()`` by ``display_summary``.
        prefix: text printed in front of the batch counter.
    """

    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Print the counter for ``batch`` followed by every meter, tab-separated."""
        header = self.prefix + self.batch_fmtstr.format(batch)
        print('\t'.join([header, *map(str, self.meters)]))

    def display_summary(self):
        """Print ' *' followed by each meter's summary, space-separated."""
        print(' '.join([" *", *(meter.summary() for meter in self.meters)]))

    def _get_batch_fmtstr(self, num_batches):
        # Pad the running batch index to the width of the total, e.g. '[{:2d}/79]'.
        width = len(str(num_batches // 1))
        return f"[{{:{width}d}}/{num_batches:{width}d}]"

In [None]:
def make_dataloaders(args, root='./data'):
    """Build CIFAR-10 train/validation DataLoaders.

    Both splits are read from the same ``root`` directory, so torchvision
    downloads and extracts the archive at most once instead of once per split.

    Args:
        args: config providing ``batch_size`` and ``workers``.
        root: directory where torchvision stores / looks for CIFAR-10.

    Returns:
        ``(train_loader, val_loader)``; the training loader shuffles and applies
        random crop + horizontal flip, the validation loader does neither.
    """
    # Per-channel mean/std normalisation shared by both splits.
    normalize = transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])

    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    valid_transform = transforms.Compose([transforms.ToTensor(), normalize])

    trainset = torchvision.datasets.CIFAR10(root=root, train=True, download=True, transform=train_transform)
    testset = torchvision.datasets.CIFAR10(root=root, train=False, download=True, transform=valid_transform)

    train_loader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers)
    val_loader = torch.utils.data.DataLoader(testset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)

    print("number of training dataset:%d" % len(trainset))
    print("number of validation dataset:%d" % len(testset))
    return train_loader, val_loader

In [None]:
def validate(val_loader, model, criterion):
    """Evaluate ``model`` on ``val_loader`` and report loss and top-1/top-5 accuracy.

    A progress line is printed every ``args.print_freq`` batches (``args`` is
    the notebook-level config), followed by a summary line.

    Returns:
        The sample-weighted mean of the per-batch top-1 accuracies (percent),
        as accumulated by the Acc@1 meter.
    """
    batch_time = AverageMeter('Time', ':6.3f', Summary.NONE)
    losses = AverageMeter('Loss', ':.4e', Summary.NONE)
    top1 = AverageMeter('Acc@1', ':6.2f', Summary.AVERAGE)
    top5 = AverageMeter('Acc@5', ':6.2f', Summary.AVERAGE)
    progress = ProgressMeter(len(val_loader), [batch_time, losses, top1, top5], prefix='Test: ')

    model.eval()  # switch to evaluate mode

    with torch.no_grad():
        tick = time.time()
        for step, (images, target) in enumerate(val_loader):
            if torch.cuda.is_available():
                images, target = images.cuda(), target.cuda()

            logits = model(images)
            loss = criterion(logits, target)

            # Weight every batch statistic by the number of samples it covers.
            n_samples = images.size(0)
            acc1, acc5 = accuracy(logits, target, topk=(1, 5))
            losses.update(loss.item(), n_samples)
            top1.update(acc1[0], n_samples)
            top5.update(acc5[0], n_samples)

            batch_time.update(time.time() - tick)
            tick = time.time()

            if step % args.print_freq == 0:
                progress.display(step + 1)

    progress.display_summary()
    return top1.avg

In [None]:
def accuracy(output, target, topk=(1,)):
    """Return the top-k accuracies (in percent) of ``output`` against ``target``.

    Args:
        output: score tensor indexed as (batch, class).
        target: class-index tensor with one entry per sample.
        topk: the k values to evaluate.

    Returns:
        A list with one 1-element float tensor per k: the percentage of samples
        whose target is among the k highest-scoring classes.
    """
    with torch.no_grad():
        n_samples = target.size(0)
        # (batch, maxk) class indices, best score first.
        ranked = output.topk(max(topk), dim=1, largest=True, sorted=True).indices
        hits = ranked.eq(target.view(-1, 1).expand_as(ranked))
        return [hits[:, :k].float().sum().reshape(1) * (100.0 / n_samples) for k in topk]

In [None]:
train_loader, val_loader = make_dataloaders(args)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./train/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:12<00:00, 13604586.60it/s]


Extracting ./train/cifar-10-python.tar.gz to ./train
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./val/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:12<00:00, 13336313.10it/s]


Extracting ./val/cifar-10-python.tar.gz to ./val
number of training dataset:50000
number of validation dataset:10000




## Functions for quantization


*   calcScaleZeroPoint : computes the scale and zero point for asymmetric quantization
*   quantize_weight : fake-quantizes (quantizes, then dequantizes) weights, symmetrically or asymmetrically
*   quantize_activation : fake-quantizes activation maps (asymmetric only)


In [None]:
def calcScaleZeroPoint(min_val, max_val, num_bits):
    """Compute affine (asymmetric) quantization parameters on an unsigned grid.

    The observed range is first widened to include 0, so real zero maps exactly
    to an integer and the zero point always lies inside [qmin, qmax]. All
    operations are element-wise, so ``min_val``/``max_val`` may also be
    per-channel tensors.

    Args:
        min_val: tensor with the observed minimum(s).
        max_val: tensor with the observed maximum(s).
        num_bits: bit width; the integer grid is [0, 2**num_bits - 1].

    Returns:
        ``(scale, zero_point)`` float tensors for the scheme
        ``q = clamp(round(x / scale) + zero_point, qmin, qmax)``,
        ``x ≈ (q - zero_point) * scale``. A degenerate all-zero range yields
        ``scale = 1`` instead of 0 so that ``x / scale`` never becomes 0/0.
    """
    qmin = 0.
    qmax = 2. ** num_bits - 1.

    # Include zero in the range so the zero point is a valid grid index.
    min_val = torch.clamp(min_val, max=0.)
    max_val = torch.clamp(max_val, min=0.)

    scale = (max_val - min_val) / (qmax - qmin)
    scale = torch.where(scale > 0, scale, torch.ones_like(scale))

    # Clamp only guards against floating-point drift at the range ends.
    zero_point = torch.clamp(torch.round(qmin - min_val / scale), qmin, qmax)
    return scale, zero_point

def quantize_weight(x, args):
    """Fake-quantize a weight tensor (quantize to ``args.bitW`` bits, then dequantize).

    Args:
        x: weight tensor; dim 0 is treated as the output-channel axis when
            ``args.per_ch`` is set.
        args: config with ``bitW`` (bit width), ``symW`` (symmetric vs.
            asymmetric) and ``per_ch`` (per-channel vs. per-tensor scale; only
            honored in symmetric mode).

    Returns:
        ``(dq_x, scale)``: the dequantized tensor (same shape as ``x``) and the
        scale used — shape ``(C, 1, ..., 1)`` in per-channel mode, a 0-d tensor
        otherwise.
    """
    bits = args.bitW
    if args.symW:
        # Signed grid [-2^(b-1), 2^(b-1)-1] with zero point 0.
        qmin = -2 ** (bits - 1)
        qmax = 2 ** (bits - 1) - 1
        if args.per_ch:
            # One magnitude per output channel, reshaped to broadcast over the rest.
            amax = x.abs().reshape(x.shape[0], -1).amax(dim=1)
            amax = amax.reshape((-1,) + (1,) * (x.dim() - 1))
        else:
            amax = x.abs().max()
        # Map the largest magnitude onto qmax (max(qmax, 1) keeps 1-bit from
        # dividing by zero); an all-zero channel/tensor gets scale 1 to avoid 0/0.
        scale = torch.where(amax > 0, amax / max(qmax, 1), torch.ones_like(amax))
        q_x = torch.clamp(torch.round(x / scale), qmin, qmax)
        return q_x * scale, scale

    # Asymmetric, per-tensor: unsigned grid [0, 2^b - 1] with a zero point.
    qmin = 0
    qmax = 2 ** bits - 1
    # Widen the range to include 0 so the zero point is a valid grid index.
    lo = torch.clamp(torch.min(x), max=0.)
    hi = torch.clamp(torch.max(x), min=0.)
    if torch.equal(lo, hi):
        # Only possible for an all-zero tensor: exactly representable, skip 0/0.
        return x.clone(), torch.ones_like(hi)
    scale, zero_point = calcScaleZeroPoint(lo, hi, bits)
    # clamp() is out-of-place, so its result must be kept.
    q_x = torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax)
    return (q_x - zero_point) * scale, scale

def quantize_activation(x, bits):
    """Fake-quantize an activation tensor on a per-tensor asymmetric (unsigned) grid.

    Args:
        x: activation tensor (in this notebook, the output of F.relu).
        bits: bit width; integer codes are clamped to [0, 2**bits - 1].

    Returns:
        A tensor with ``x``'s shape holding the dequantized values.
    """
    qmin = 0
    qmax = 2 ** bits - 1
    # Widen the observed range to include 0 so the zero point is a valid grid index.
    lo = torch.clamp(torch.min(x), max=0.)
    hi = torch.clamp(torch.max(x), min=0.)
    if torch.equal(lo, hi):
        # Only possible when x is all zeros: exactly representable, and skipping avoids 0/0.
        return x.clone()
    scale, zero_point = calcScaleZeroPoint(lo, hi, bits)
    # Quantize: q = round(x / s) + zp, clipped to the b-bit grid (clamp is out-of-place).
    q_x = torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax)
    # Dequantize back to float.
    return (q_x - zero_point) * scale

## Layers for Quantization


*   QConv2d : Conv2d layer whose weights are quantized in the forward pass  
*   QLinear : Linear layer whose weights are quantized in the forward pass  
*   QReLU : ReLU whose output activations are quantized  


In [None]:
class QConv2d(nn.Conv2d):
    """Conv2d whose weights are fake-quantized to ``args.bitW`` bits on every forward.

    A bit width of 32 disables quantization. The bias (if any) is used as-is.
    Note the positional order: ``args`` comes before ``kernel_size``.
    """

    def __init__(self, in_channels, out_channels, args, kernel_size, stride=1,padding=0, dilation=1, groups=1, bias=False):
        super().__init__(in_channels, out_channels, kernel_size, stride=stride, padding=padding,
                         dilation=dilation, groups=groups, bias=bias)
        self.args = args
        self.k = args.bitW  # bit width captured at construction time

    def forward(self, input):
        weight = self.weight if self.k == 32 else quantize_weight(self.weight, self.args)[0]
        return F.conv2d(input, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)

In [None]:
class QLinear(nn.Linear):
    """Linear layer whose weights are fake-quantized to ``args.bitW`` bits on every forward.

    A bit width of 32 disables quantization. The bias (if any) is used as-is.
    """

    def __init__(self, in_features, out_features, args, bias=False):
        super().__init__(in_features=in_features, out_features=out_features, bias=bias)
        self.args = args
        self.k = args.bitW  # bit width captured at construction time

    def forward(self, input):
        if self.k == 32:
            return F.linear(input, self.weight, self.bias)
        weight, _scale = quantize_weight(self.weight, self.args)
        return F.linear(input, weight, self.bias)

In [None]:
class QReLU(nn.ReLU):
    """ReLU followed by activation fake-quantization to ``bitA`` bits (32 = off).

    ``inplace`` is forwarded to nn.ReLU (so it appears in the repr), but
    ``forward`` always uses the out-of-place ``F.relu``.
    """

    def __init__(self, bitA, inplace=False):
        super().__init__(inplace=inplace)
        self.k = bitA

    def forward(self, input):
        activated = F.relu(input)
        return activated if self.k == 32 else quantize_activation(activated, self.k)

## CIFAR10_ResNet20 model network



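*   Build ResNet-20 from the quantized layers and load weights pretrained on CIFAR-10 (the final `fc` classifier is a plain `nn.Linear` and is not quantized)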



In [None]:
import sys
import torch.nn as nn
try:
    from torch.hub import load_state_dict_from_url
except ImportError:
    from torch.utils.model_zoo import load_url as load_state_dict_from_url

from functools import partial
from typing import Dict, Type, Any, Callable, Union, List, Optional

# Checkpoint URL for the pretrained CIFAR-10 ResNet-20 (chenyaofo/pytorch-cifar-models release).
model_urls = (
    'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/'
    'resnet/cifar10_resnet20-4118986f.pt'
)

class BasicBlock(nn.Module):
    """Residual block: two 3x3 QConv2d + BatchNorm layers plus a shortcut.

    The same QReLU module is applied after the first conv/BN and after the
    residual addition. ``downsample`` (if given) transforms the shortcut input.
    """
    expansion = 1

    def __init__(self, inplanes, planes, args, stride=1, downsample=None):
        super().__init__()
        # Registration order is kept (conv1, bn1, QReLU, conv2, bn2, downsample).
        self.conv1 = QConv2d(inplanes, planes, args, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(planes)
        self.QReLU = QReLU(args.bitA, inplace=True)
        self.conv2 = QConv2d(planes, planes, args, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        out = self.QReLU(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        return self.QReLU(out)


class CifarResNet(nn.Module):
    """ResNet for CIFAR-sized inputs built from quantized layers.

    Layout: 3x3 QConv2d stem (16 channels) + BatchNorm + QReLU, three stages of
    ``block`` with 16/32/64 planes (stages 2 and 3 use stride 2), adaptive
    average pooling to 1x1 and a plain ``nn.Linear`` classifier.

    Args:
        block: residual block class (must expose ``expansion``).
        layers: number of blocks in each of the three stages.
        args: quantization config forwarded to every QConv2d / QReLU.
        num_classes: size of the classifier output.
    """

    def __init__(self, block, layers, args, num_classes=10):
        super().__init__()
        self.inplanes = 16
        self.args = args
        self.conv1 = QConv2d(3, 16, args, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(16)
        self.QReLU = QReLU(args.bitA, inplace=True)

        self.layer1 = self._make_layer(block, 16, layers[0])
        self.layer2 = self._make_layer(block, 32, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 64, layers[2], stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(64 * block.expansion, num_classes)

        # Kaiming init for every conv (QConv2d included); BN starts as identity.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

    def _make_layer(self, block, planes, blocks, stride=1):
        out_planes = planes * block.expansion
        # A 1x1 projection shortcut is needed whenever shape or resolution changes.
        needs_projection = stride != 1 or self.inplanes != out_planes
        downsample = (
            nn.Sequential(
                QConv2d(self.inplanes, out_planes, self.args, kernel_size=1, stride=stride),
                nn.BatchNorm2d(out_planes),
            )
            if needs_projection
            else None
        )
        first = block(self.inplanes, planes, self.args, stride, downsample)
        self.inplanes = out_planes
        rest = [block(self.inplanes, planes, self.args) for _ in range(1, blocks)]
        return nn.Sequential(first, *rest)

    def forward(self, x):
        x = self.QReLU(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3):
            x = stage(x)
        x = self.avgpool(x)
        return self.fc(torch.flatten(x, 1))

def _resnet(
    arch: str,
    layers: List[int],
    model_urls: str,
    progress: bool = True,
    pretrained: bool = True,
    **kwargs: Any
) -> CifarResNet:
    """Build a CifarResNet of BasicBlocks and optionally load pretrained weights.

    Args:
        arch: architecture label (informational; not used by the builder).
        layers: number of BasicBlocks in each of the three stages.
        model_urls: URL of the checkpoint to load (a single URL string).
        progress: show a progress bar while downloading.
        pretrained: download and load the checkpoint's state_dict.
        **kwargs: forwarded to CifarResNet (e.g. ``args``, ``num_classes``).

    Returns:
        The constructed model (still on the CPU).
    """
    model = CifarResNet(BasicBlock, layers, **kwargs)
    if pretrained:
        # The freshly built model lives on the CPU, so load the checkpoint there
        # too; this also works on machines without a GPU.
        state_dict = load_state_dict_from_url(model_urls, progress=progress, map_location="cpu")
        model.load_state_dict(state_dict)
    return model

# Factory for the CIFAR-10 ResNet-20 (three BasicBlocks per stage), bound to the
# notebook-level quantization config ``args``. Extra keywords (e.g.
# ``pretrained=False``) are forwarded to ``_resnet``.
cifar10_resnet20 = partial(
    _resnet,
    arch="resnet20",
    layers=[3] * 3,
    model_urls=model_urls,
    args=args,
    num_classes=10,
)

In [None]:
# Build ResNet-20 from the quantized layers (pretrained CIFAR-10 weights are loaded
# by default), then move it to the selected device. Module.to moves in place; as
# the cell's last expression it also displays the model architecture.
model = cifar10_resnet20()
model.to(device)

Downloading: "https://github.com/chenyaofo/pytorch-cifar-models/releases/download/resnet/cifar10_resnet20-4118986f.pt" to /root/.cache/torch/hub/checkpoints/cifar10_resnet20-4118986f.pt
100%|██████████| 1.09M/1.09M [00:01<00:00, 778kB/s]


CifarResNet(
  (conv1): QConv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (QReLU): QReLU(inplace=True)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): QConv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (QReLU): QReLU(inplace=True)
      (conv2): QConv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): QConv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (QReLU): QReLU(inplace=True)
      (conv2): QConv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(

In [None]:
# Evaluate the quantized model on the validation split and report the wall time.
print(args)
start = time.perf_counter()  # monotonic, high-resolution clock meant for intervals
validate(val_loader, model, criterion)
print(time.perf_counter() - start)

{'workers': 4, 'batch_size': 128, 'print_freq': 10, 'bitW': 4, 'bitA': 4, 'symW': True, 'per_ch': False, 'symA': False}
Test: [ 1/79]	Time  8.036 ( 8.036)	Loss 4.6690e-01 (4.6690e-01)	Acc@1  85.94 ( 85.94)	Acc@5  99.22 ( 99.22)
Test: [11/79]	Time  0.034 ( 0.769)	Loss 5.9015e-01 (5.9268e-01)	Acc@1  82.81 ( 83.95)	Acc@5  97.66 ( 99.08)
Test: [21/79]	Time  0.030 ( 0.424)	Loss 5.2312e-01 (5.9590e-01)	Acc@1  87.50 ( 84.97)	Acc@5  98.44 ( 98.96)
Test: [31/79]	Time  0.040 ( 0.302)	Loss 2.7344e-01 (5.8534e-01)	Acc@1  93.75 ( 85.28)	Acc@5  99.22 ( 98.92)
Test: [41/79]	Time  0.055 ( 0.238)	Loss 8.8165e-01 (5.8599e-01)	Acc@1  80.47 ( 85.23)	Acc@5  98.44 ( 98.93)
Test: [51/79]	Time  0.076 ( 0.200)	Loss 5.9611e-01 (5.8055e-01)	Acc@1  85.16 ( 85.39)	Acc@5 100.00 ( 99.05)
Test: [61/79]	Time  0.037 ( 0.174)	Loss 8.3705e-01 (5.8198e-01)	Acc@1  80.47 ( 85.23)	Acc@5  99.22 ( 99.08)
Test: [71/79]	Time  0.060 ( 0.157)	Loss 6.9982e-01 (5.9056e-01)	Acc@1  84.38 ( 85.13)	Acc@5  99.22 ( 99.06)
 *   Acc@1 85.37