## Section 1: First we try PyTorch quantization.
### As we will see, in terms of accuracy preservation this is the best available option.
### However, the FBGEMM backend is not available on FPGAs and other embedded systems, so in the next section we implement quantization from scratch.

In [1]:
# Import required packages.
import os
import time
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.prune as prune
import torchvision
from torchvision import transforms, datasets
import numpy as np
import warnings
# Project-local model definition; this wildcard import is presumably what
# provides resnet18() used when the pretrained model is loaded — confirm in
# resnet_18.py.
from resnet_18 import *
# NOTE(review): this silences every warning for the whole session (including
# any deprecation warnings); consider narrowing it to specific categories.
warnings.filterwarnings('ignore')

Inside main  forward shape:  torch.Size([1, 3, 64, 64])
Inside main  forward conv1 shape:  torch.Size([1, 64, 32, 32])
Inside main  forward pool shape:  torch.Size([1, 64, 16, 16])
Inside layer shape:  torch.Size([1, 64, 16, 16])
Inside layer conv1 shape:  torch.Size([1, 64, 16, 16])
Inside layer conv2 shape:  torch.Size([1, 64, 16, 16])
Inside layer out shape:  torch.Size([1, 64, 16, 16])
Inside layer shape:  torch.Size([1, 64, 16, 16])
Inside layer conv1 shape:  torch.Size([1, 64, 16, 16])
Inside layer conv2 shape:  torch.Size([1, 64, 16, 16])
Inside layer out shape:  torch.Size([1, 64, 16, 16])
Inside main  forward layer 1 shape:  torch.Size([1, 64, 16, 16])
Inside layer shape:  torch.Size([1, 64, 16, 16])
Inside layer conv1 shape:  torch.Size([1, 64, 16, 16])
Inside layer conv2 shape:  torch.Size([1, 64, 16, 16])
Inside layer out shape:  torch.Size([1, 128, 8, 8])
Inside layer shape:  torch.Size([1, 128, 8, 8])
Inside layer conv1 shape:  torch.Size([1, 128, 8, 8])
Inside layer conv

In [2]:
# Load the Tiny ImageNet validation set. The training set is loaded as well
# because its class order defines the label indices.
torch.manual_seed(0)
np.random.seed(0)
torch.use_deterministic_algorithms(True)
directory = "../../dataset/tiny-imagenet-200/"
num_classes = 200
# the magic normalization parameters (ImageNet channel mean/std) come from the example
transform_mean = np.array([ 0.485, 0.456, 0.406 ])
transform_std = np.array([ 0.229, 0.224, 0.225 ])

val_transform = transforms.Compose([
    #transforms.Resize(256),
    #transforms.CenterCrop(224),
    transforms.Resize(74),
    transforms.CenterCrop(64),
    transforms.ToTensor(),
    transforms.Normalize(mean = transform_mean, std = transform_std),
])


##### Related to trainset, needed for the label ids ##############
# NOTE(review): train_loader is also used later as the calibration set
# (PyTorch quantization and gatherStats), so calibration images go through the
# random augmentations below; a val_transform-style pipeline may give more
# representative activation ranges.
traindir = os.path.join(directory, "train")
bs = 1
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(64),
    transforms.AutoAugment(transforms.AutoAugmentPolicy.IMAGENET),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean = transform_mean, std = transform_std),
])
train = datasets.ImageFolder(traindir, train_transform)
train_loader = torch.utils.data.DataLoader(train, batch_size=bs, shuffle=True)
assert num_classes == len(train_loader.dataset.classes)

# words.txt: "<label id>\t<human readable label>" per line (parsed once).
small_labels = {}
with open(os.path.join(directory, "words.txt"), "r") as dictionary_file:
    for line in dictionary_file:
        line = line.strip()
        if not line:
            continue  # tolerate blank lines instead of failing to unpack
        label_id, label = line.split("\t", 1)
        small_labels[label_id] = label

labels = {}
label_ids = {}
for label_index, label_id in enumerate(train_loader.dataset.classes):
    labels[label_index] = small_labels[label_id]
    label_ids[label_id] = label_index
############# All these just to get the label ids ############################

valdir = os.path.join(directory, "val")

val = datasets.ImageFolder(valdir, val_transform)

val_loader = torch.utils.data.DataLoader(val, batch_size=bs, shuffle=True)

# val_annotations.txt: file name -> label id (remaining columns ignored).
val_label_map = {}
with open(os.path.join(directory, "val/val_annotations.txt"), "r") as val_label_file:
    for line in val_label_file:
        line = line.strip()
        if not line:
            continue
        file_name, label_id, _, _, _, _ = line.split("\t")
        val_label_map[file_name] = label_id

# ImageFolder labels each validation image by its parent folder; replace that
# label with the real class index from the annotations file.
for i, (file_path, _) in enumerate(val.imgs):
    label_index = label_ids[val_label_map[os.path.basename(file_path)]]
    val.imgs[i] = (file_path, label_index)
    # keep .targets consistent with the relabelled samples
    val.targets[i] = label_index

# images, labels = next(iter(val_loader))

# print(images.shape)
# print(labels.shape)

# figure = plt.figure(figsize = (24,16))
# num_of_images = 8
# for index in range(1, num_of_images + 1):
#     plt.subplot(6, 10, index)
#     plt.axis('off')
#     plt.imshow(images[index].permute(1,2,0).numpy().squeeze().astype('uint8'), cmap='summer')


In [3]:
# ResNet-18 model: resnet18() presumably comes from the wildcard
# `from resnet_18 import *` in the first cell. The pretrained weights are read
# on the CPU from the checkpoint's 'net' entry.
pretrained_model = "./tinyimg_resnet_std.pt"
net = resnet18()
sd = torch.load(pretrained_model, map_location='cpu')
net.load_state_dict(sd['net'])

<All keys matched successfully>

In [7]:
# Show the model architecture for reference.
print(str(net))

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [4]:
# Helper function for getting the model size
def print_size_of_model(model):
    """Print the serialized size of ``model``'s state_dict in MB.

    The state_dict is serialized with ``torch.save`` into an in-memory buffer,
    so no temporary file is written to (or left behind in) the working
    directory, even if serialization fails.

    Args:
        model: module whose size needs to be determined.
    """
    import io

    buffer = io.BytesIO()
    torch.save(model.state_dict(), buffer)
    size_mb = buffer.getbuffer().nbytes / (1024 * 1024)
    print('Size of the model(MB):', round(size_mb, 3))

In [None]:
print_size_of_model(model=net)  # size of the float (unquantized) model

In [5]:
# Main accuracy testing function

def test(model, device, test_loader, train_loader, batch_size, quantize=False, fbgemm=False,
         max_batches=None):
    """Measure top-1 accuracy, model size and wall-clock time of ``model``.

    When ``quantize`` is True the model is first quantized in place with
    PyTorch eager-mode post-training static quantization: observers are
    inserted with ``torch.quantization.prepare``, calibrated by running every
    batch of ``train_loader``, and the model is converted with
    ``torch.quantization.convert``.

    Args:
        model: module to evaluate (moved to ``device``, put in eval mode and,
            if ``quantize``, modified in place).
        device: device the model and every batch are moved to.
        test_loader: iterable of ``(inputs, targets)`` batches to score.
        train_loader: iterable of ``(inputs, targets)`` batches used only for
            calibration when ``quantize`` is True.
        batch_size: kept for call compatibility; the sample count is taken
            from each batch so a short final batch is counted correctly.
        quantize: quantize the model before evaluating it.
        fbgemm: use the 'fbgemm' default qconfig instead of
            ``torch.quantization.default_qconfig``.
        max_batches: if given, stop after this many test batches; ``None``
            evaluates the whole loader.

    Raises:
        ValueError: if a batch's output rows do not match its labels or if
            ``test_loader`` yields no samples.
    """
    model.to(device)
    model.eval()

    # Testing with quantization if quantize=True
    if quantize:
#         modules_to_fuse = [['conv1', 'bn1'],
#                    ['layer1.0.conv1', 'layer1.0.bn1'],
#                    ['layer1.0.conv2', 'layer1.0.bn2'],
#                    ['layer1.1.conv1', 'layer1.1.bn1'],
#                    ['layer1.1.conv2', 'layer1.1.bn2'],
#                    ['layer2.0.conv1', 'layer2.0.bn1'],
#                    ['layer2.0.conv2', 'layer2.0.bn2'],
#                    ['layer2.0.downsample.0', 'layer2.0.downsample.1'],
#                    ['layer2.1.conv1', 'layer2.1.bn1'],
#                    ['layer2.1.conv2', 'layer2.1.bn2'],
#                    ['layer3.0.conv1', 'layer3.0.bn1'],
#                    ['layer3.0.conv2', 'layer3.0.bn2'],
#                    ['layer3.0.downsample.0', 'layer3.0.downsample.1'],
#                    ['layer3.1.conv1', 'layer3.1.bn1'],
#                    ['layer3.1.conv2', 'layer3.1.bn2'],
#                    ['layer4.0.conv1', 'layer4.0.bn1'],
#                    ['layer4.0.conv2', 'layer4.0.bn2'],
#                    ['layer4.0.downsample.0', 'layer4.0.downsample.1'],
#                    ['layer4.1.conv1', 'layer4.1.bn1'],
#                    ['layer4.1.conv2', 'layer4.1.bn2']]
#         model = torch.quantization.fuse_modules(model, modules_to_fuse)
        # NOTE(review): eager-mode quantization expects QuantStub/DeQuantStub
        # around the quantized region; if resnet_18 does not define them the
        # converted model receives float inputs — confirm in resnet_18.py.
        if fbgemm:
            model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
        else:
            model.qconfig = torch.quantization.default_qconfig
        torch.quantization.prepare(model, inplace=True)
        model.eval()
        with torch.no_grad():
            # Calibration: let the observers record activation ranges.
            for data, _ in train_loader:
                model(data.to(device))
        torch.quantization.convert(model, inplace=True)
        print("======= Quantization Done =====")

    correct = 0
    total = 0
    with torch.no_grad():
        st = time.time()
        for batch_idx, (X, y) in enumerate(test_loader):
            if max_batches is not None and batch_idx >= max_batches:
                break
            X, y = X.to(device), y.to(device)
            output = model(X)
            if output.shape[0] != y.shape[0]:
                raise ValueError("model produced {} output rows for {} labels".format(
                    output.shape[0], y.shape[0]))
            correct += (output.argmax(dim=1) == y).sum().item()
            total += y.size(0)
        et = time.time()
    if total == 0:
        raise ValueError("test_loader yielded no samples")
    acc = correct / total
    print("========================================= PERFORMANCE =============================================")
    print_size_of_model(model)
    print("PyTorch optimized model test accuracy :{:.2f}% ".format(100 * acc))
    print('Elapsed time = {:0.4f} milliseconds'.format((et - st) * 1000))
    print("====================================================================================================")

In [6]:
# Baseline: accuracy, size and latency of the unquantized float model on CPU.
device = 'cpu'
test(net, device, val_loader, train_loader, bs)

Inside main  forward shape:  torch.Size([1, 3, 64, 64])
Inside main  forward conv1 shape:  torch.Size([1, 64, 32, 32])
Inside main  forward pool shape:  torch.Size([1, 64, 16, 16])
Inside layer shape:  torch.Size([1, 64, 16, 16])
Inside layer conv1 shape:  torch.Size([1, 64, 16, 16])
Inside layer conv2 shape:  torch.Size([1, 64, 16, 16])
Inside layer out shape:  torch.Size([1, 64, 16, 16])
Inside layer shape:  torch.Size([1, 64, 16, 16])
Inside layer conv1 shape:  torch.Size([1, 64, 16, 16])
Inside layer conv2 shape:  torch.Size([1, 64, 16, 16])
Inside layer out shape:  torch.Size([1, 64, 16, 16])
Inside main  forward layer 1 shape:  torch.Size([1, 64, 16, 16])
Inside layer shape:  torch.Size([1, 64, 16, 16])
Inside layer conv1 shape:  torch.Size([1, 64, 16, 16])
Inside layer conv2 shape:  torch.Size([1, 64, 16, 16])
Inside layer out shape:  torch.Size([1, 128, 8, 8])
Inside layer shape:  torch.Size([1, 128, 8, 8])
Inside layer conv1 shape:  torch.Size([1, 128, 8, 8])
Inside layer conv

In [None]:
# Quantization performance: eager-mode PyTorch post-training quantization of a
# copy of the model, so `net` itself is left untouched.
# Unfortunately this is expected to fail for this model.
import copy
device = 'cpu'
netq = copy.deepcopy(net)
test(netq, device, val_loader, train_loader, bs, quantize=True)

### Huh! PyTorch quantization fails for this ResNet inside the QuantizedCPU implementations.
### Even if it worked, it would not help an FPGA implementation: the FBGEMM backend is not available on FPGAs and other embedded systems. In the next section we implement quantization from scratch.

In [None]:
# Post-training quantization from scratch, with calibration.
# A deliberately simple implementation intended to map onto an FPGA.

from collections import namedtuple

# A quantized tensor: integer-valued data plus the (scale, zero_point) pair
# that maps it back to real values.
QTensor = namedtuple('QTensor', 'tensor scale zero_point')
# Default bit width used by the routines below.
nb = 8
def calcScaleZeroPoint(min_val, max_val,num_bits=nb):
  """Return ``(scale, zero_point)`` for asymmetric unsigned quantization.

  Maps the real range ``[min_val, max_val]`` onto the integer grid
  ``[0, 2**num_bits - 1]`` so that ``real = scale * (q - zero_point)``.
  The range is first widened to contain 0. Real zero is then exactly
  representable. When all values share a sign, the whole range still fits on
  the grid; a clamped zero point would otherwise clip one end.

  Args:
    min_val, max_val: calibration range (Python numbers or 0-d tensors).
    num_bits: bit width of the quantized values.

  Returns:
    ``(scale, zero_point)``: ``scale`` is a positive float and ``zero_point``
    is an int in ``[0, 2**num_bits - 1]``. A degenerate all-zero range
    yields ``(1.0, 0)`` instead of dividing by zero.
  """
  qmin = 0.
  qmax = 2.**num_bits - 1.

  min_val = min(float(min_val), 0.)
  max_val = max(float(max_val), 0.)
  if max_val == min_val:
    return 1.0, 0

  scale = (max_val - min_val) / (qmax - qmin)

  initial_zero_point = qmin - min_val / scale
  # Round (not truncate) to the nearest grid point and keep it in range.
  zero_point = int(round(min(max(initial_zero_point, qmin), qmax)))

  return scale, zero_point

def quantize_tensor(x, num_bits=nb, min_val=None, max_val=None):
    """Asymmetrically quantize ``x`` onto the unsigned ``num_bits`` grid.

    Args:
        x: float tensor to quantize.
        num_bits: bit width of the quantized values.
        min_val, max_val: calibration range; a bound left as ``None`` is
            taken from ``x`` itself. A bound of 0 is honoured, not treated as
            missing.

    Returns:
        ``QTensor(tensor, scale, zero_point)``. ``tensor`` is uint8 for
        ``num_bits <= 8`` and int32 otherwise, so wider grids do not wrap.
    """
    if min_val is None:
        min_val = x.min()
    if max_val is None:
        max_val = x.max()

    qmin = 0.
    qmax = 2.**num_bits - 1.

    scale, zero_point = calcScaleZeroPoint(min_val, max_val, num_bits)
    q_x = (zero_point + x / scale).clamp(qmin, qmax).round()
    q_x = q_x.to(torch.uint8 if num_bits <= 8 else torch.int32)

    return QTensor(tensor=q_x, scale=scale, zero_point=zero_point)

def dequantize_tensor(q_x):
    """Map a quantized tensor back to real values: ``scale * (q - zero_point)``."""
    shifted = q_x.tensor.float() - q_x.zero_point
    return q_x.scale * shifted

def calcScaleZeroPoint(min_val, max_val,num_bits=nb):
  """Return ``(scale_next, zero_point_next)`` for asymmetric unsigned quantization.

  This redefinition replaces the identical-contract version earlier in this
  cell. It maps ``[min_val, max_val]`` onto ``[0, 2**num_bits - 1]`` with
  ``real = scale * (q - zero_point)``. The range is first widened to contain
  0, so real zero is exact and a one-signed range is not clipped by a clamped
  zero point.

  Args:
    min_val, max_val: calibration range (Python numbers or 0-d tensors).
    num_bits: bit width of the quantized values.

  Returns:
    ``(scale_next, zero_point_next)``: a positive float scale and an int zero
    point in ``[0, 2**num_bits - 1]``. An all-zero range yields ``(1.0, 0)``.
  """
  qmin = 0.
  qmax = 2.**num_bits - 1.

  min_val = min(float(min_val), 0.)
  max_val = max(float(max_val), 0.)
  if max_val == min_val:
    return 1.0, 0

  scale_next = (max_val - min_val) / (qmax - qmin)

  initial_zero_point = qmin - min_val / scale_next
  # Round (not truncate) to the nearest grid point and keep it in range.
  zero_point_next = int(round(min(max(initial_zero_point, qmin), qmax)))

  return scale_next, zero_point_next
  
def quantizeLayer(x, layer, stat, scale_x, zp_x, num_bits=nb):
  """Run a conv/linear ``layer`` on an already-quantized input.

  The weights are quantized to ``num_bits``. The rescaling factor
  ``scale_x * scale_w / scale_next`` is folded into them. The layer is then run
  on the zero-point-shifted input, so its output already lies on the output
  grid derived from ``stat``; only the output zero point is added afterwards.
  The layer's float weight is restored even if the forward pass raises.

  Args:
    x: input on the quantized grid (integer values, any dtype).
    layer: conv or linear module; its weight is temporarily replaced.
    stat: dict with ``'min'``/``'max'`` used for the output grid.
    scale_x, zp_x: quantization parameters the input ``x`` was quantized with.
    num_bits: bit width for the weights and the output grid.

  Returns:
    ``(x_q, scale_next, zero_point_next)``: the (unrounded, float-typed) output
    on the quantized grid and its quantization parameters.

  NOTE(review): a bias, if the layer has one, is applied in float without
  rescaling (bias handling was commented out). The conv layers in the printed
  model are bias=False.
  """
  W = layer.weight.data

  w = quantize_tensor(W, num_bits)
  scale_next, zero_point_next = calcScaleZeroPoint(min_val=stat['min'], max_val=stat['max'], num_bits=num_bits)

  # Preparing input by shifting its zero point to 0.
  X = x.float() - zp_x
  try:
    # real_out / scale_next = (s_x * s_w / s_next) * sum((w_q - zp_w) * (x_q - zp_x)),
    # so the 1/scale_next factor is already part of the weights.
    layer.weight.data = (scale_x * w.scale / scale_next) * (w.tensor.float() - w.zero_point)
    # Dividing by scale_next again here would rescale twice.
    x = layer(X) + zero_point_next
  finally:
    # Reset the float weights for the next forward pass.
    layer.weight.data = W

  return x, scale_next, zero_point_next


def quantForward(model, x, stats):
  """Simulated quantized forward pass of the ResNet using calibration stats.

  The input is quantized with the ``'conv1'`` stats. Every conv is then run
  through ``quantizeLayer``, whose output grid comes from the stats key passed
  alongside it. The result is dequantized just before ``model.fc``.

  Args:
    model: the float ResNet whose sub-modules are reused.
    x: input batch (float).
    stats: per-site ``{'min', 'max'}`` dict as returned by ``gatherStats``.

  Returns:
    The output of ``model.fc``.

  NOTE(review): the stats key given to each conv is the one
  gatherActivationStats records *after* that conv's BatchNorm (and ReLU). Here
  the quantizeLayer output is taken *before* the BatchNorm, and bn/relu are
  then applied to values on the quantized grid. Presumably a large source of
  accuracy loss — confirm.
  NOTE(review): this pass never calls model.maxpool (present in the printed
  model), keeps the downsample calls commented out and adds no residual
  connections. If resnet_18's BasicBlock adds the identity (standard ResNet),
  this is not the same network as model.forward — confirm in resnet_18.py.
  """

  # Quantise before inputting into incoming layers
  x = quantize_tensor(x, min_val=stats['conv1']['min'], max_val=stats['conv1']['max'])

  x, scale_next, zero_point_next = quantizeLayer(x.tensor, model.conv1, stats['l10c1'], x.scale, x.zero_point)
  x = model.bn1(x)
  x = model.relu(x)

  x, scale_next, zero_point_next = quantizeLayer(x, model.layer1[0].conv1, stats['l10c2'], scale_next, zero_point_next)
  x = model.layer1[0].bn1(x)
  x = model.layer1[0].relu(x)
  x, scale_next, zero_point_next = quantizeLayer(x, model.layer1[0].conv2, stats['l11c1'], scale_next, zero_point_next)
  x = model.layer1[0].bn2(x)
  x, scale_next, zero_point_next = quantizeLayer(x, model.layer1[1].conv1, stats['l11c2'], scale_next, zero_point_next)
  x = model.layer1[1].bn1(x)
  x = model.layer1[1].relu(x)
  x, scale_next, zero_point_next = quantizeLayer(x, model.layer1[1].conv2, stats['l20c1'], scale_next, zero_point_next)
  x = model.layer1[1].bn2(x)

  x, scale_next, zero_point_next = quantizeLayer(x, model.layer2[0].conv1, stats['l20c2'], scale_next, zero_point_next)
  x = model.layer2[0].bn1(x)
  x = model.layer2[0].relu(x)
  x, scale_next, zero_point_next = quantizeLayer(x, model.layer2[0].conv2, stats['l21c1'], scale_next, zero_point_next)
  x = model.layer2[0].bn2(x)
  #x = model.layer2[0].downsample[0](x)
  #x = model.layer2[0].downsample[1](x)
  x, scale_next, zero_point_next = quantizeLayer(x, model.layer2[1].conv1, stats['l21c2'], scale_next, zero_point_next)
  x = model.layer2[1].bn1(x)
  x = model.layer2[1].relu(x)
  x, scale_next, zero_point_next = quantizeLayer(x, model.layer2[1].conv2, stats['l30c1'], scale_next, zero_point_next)
  x = model.layer2[1].bn2(x)

  x, scale_next, zero_point_next = quantizeLayer(x, model.layer3[0].conv1, stats['l30c2'], scale_next, zero_point_next)
  x = model.layer3[0].bn1(x)
  x = model.layer3[0].relu(x)
  x, scale_next, zero_point_next = quantizeLayer(x, model.layer3[0].conv2, stats['l31c1'], scale_next, zero_point_next)
  x = model.layer3[0].bn2(x)
  #x = model.layer3[0].downsample[0](x)
  #x = model.layer3[0].downsample[1](x)
  x, scale_next, zero_point_next = quantizeLayer(x, model.layer3[1].conv1, stats['l31c2'], scale_next, zero_point_next)
  x = model.layer3[1].bn1(x)
  x = model.layer3[1].relu(x)
  x, scale_next, zero_point_next = quantizeLayer(x, model.layer3[1].conv2, stats['l40c1'], scale_next, zero_point_next)
  x = model.layer3[1].bn2(x)

  x, scale_next, zero_point_next = quantizeLayer(x, model.layer4[0].conv1, stats['l40c2'], scale_next, zero_point_next)
  x = model.layer4[0].bn1(x)
  x = model.layer4[0].relu(x)
  x, scale_next, zero_point_next = quantizeLayer(x, model.layer4[0].conv2, stats['l41c1'], scale_next, zero_point_next)
  x = model.layer4[0].bn2(x)
  #x = model.layer4[0].downsample[0](x)
  #x = model.layer4[0].downsample[1](x)
  x, scale_next, zero_point_next = quantizeLayer(x, model.layer4[1].conv1, stats['l41c2'], scale_next, zero_point_next)
  x = model.layer4[1].bn1(x)
  x = model.layer4[1].relu(x)
  x, scale_next, zero_point_next = quantizeLayer(x, model.layer4[1].conv2, stats['fc'], scale_next, zero_point_next)
  x = model.layer4[1].bn2(x)

  # NOTE(review): no pooling before this view; if the layer4 output is larger
  # than 1x1 spatially, each sample yields several 512-wide rows — confirm.
  x = x.view(-1, 512)

  # Back to dequant for final layer
  # NOTE(review): scale_next/zero_point_next describe the last quantizeLayer
  # output, but bn2 has been applied to x since then.
  x = dequantize_tensor(QTensor(tensor=x, scale=scale_next, zero_point=zero_point_next))

  x = model.fc(x)

  return x


# Get Min and max of x tensor, and stores it
def updateStats(x, stats, key):
  """Accumulate the per-sample min/max of ``x`` into ``stats[key]``.

  Args:
    x: 2-D tensor ``(batch, features)``; the callers flatten activations with
      ``view(x.shape[0], -1)``.
    stats: dict being accumulated; mutated in place and returned.
    key: name of the activation site.

  Returns:
    ``stats``. ``stats[key]`` holds the running sums ``'max'``/``'min'`` of the
    per-sample extrema (Python floats). It also holds ``'total'``, the number
    of *samples* summed, so ``sum / total`` is the mean per-sample extremum for
    any batch size.
  """
  max_val, _ = torch.max(x, dim=1)
  min_val, _ = torch.min(x, dim=1)

  entry = stats.setdefault(key, {"max": 0.0, "min": 0.0, "total": 0})
  entry["max"] += max_val.sum().item()
  entry["min"] += min_val.sum().item()
  entry["total"] += x.shape[0]

  return stats

# Reworked Forward Pass to access activation Stats through updateStats function
def gatherActivationStats(model, x, stats):
  """Float forward pass that records activation ranges via ``updateStats``.

  Each ``updateStats`` call records the tensor about to enter the layer named
  by its key: ``'conv1'`` is the network input, ``'l10c1'`` the input of
  ``model.layer1[0].conv1``, and so on up to ``'fc'``, the input of ``model.fc``.

  Args:
    model: the float ResNet whose sub-modules are called one by one.
    x: input batch; flattened per sample before each ``updateStats`` call.
    stats: dict accumulated across calls.

  Returns:
    The updated ``stats`` dict (the ``model.fc`` output is discarded).

  NOTE(review): this pass never calls model.maxpool (present in the printed
  model), keeps the downsample calls commented out and adds no residual
  connections. If resnet_18's BasicBlock adds the identity (standard ResNet),
  these statistics come from a different network than model.forward —
  confirm in resnet_18.py.
  """

  stats = updateStats(x.clone().view(x.shape[0], -1), stats, 'conv1')
  x = model.conv1(x)
  x = model.bn1(x)
  x = model.relu(x)

  stats = updateStats(x.clone().view(x.shape[0], -1), stats, 'l10c1')
  x = model.layer1[0].conv1(x)
  x = model.layer1[0].bn1(x)
  x = model.layer1[0].relu(x)
  stats = updateStats(x.clone().view(x.shape[0], -1), stats, 'l10c2')
  x = model.layer1[0].conv2(x)
  x = model.layer1[0].bn2(x)
  stats = updateStats(x.clone().view(x.shape[0], -1), stats, 'l11c1')
  x = model.layer1[1].conv1(x)
  x = model.layer1[1].bn1(x)
  x = model.layer1[1].relu(x)
  stats = updateStats(x.clone().view(x.shape[0], -1), stats, 'l11c2')
  x = model.layer1[1].conv2(x)
  x = model.layer1[1].bn2(x)
  stats = updateStats(x.clone().view(x.shape[0], -1), stats, 'l20c1')
  x = model.layer2[0].conv1(x)
  x = model.layer2[0].bn1(x)
  x = model.layer2[0].relu(x)
  stats = updateStats(x.clone().view(x.shape[0], -1), stats, 'l20c2')
  x = model.layer2[0].conv2(x)
  x = model.layer2[0].bn2(x)
  #x = model.layer2[0].downsample[0](x)
  #x = model.layer2[0].downsample[1](x)
  stats = updateStats(x.clone().view(x.shape[0], -1), stats, 'l21c1')
  x = model.layer2[1].conv1(x)
  x = model.layer2[1].bn1(x)
  x = model.layer2[1].relu(x)
  stats = updateStats(x.clone().view(x.shape[0], -1), stats, 'l21c2')
  x = model.layer2[1].conv2(x)
  x = model.layer2[1].bn2(x)
  stats = updateStats(x.clone().view(x.shape[0], -1), stats, 'l30c1')
  x = model.layer3[0].conv1(x)
  x = model.layer3[0].bn1(x)
  x = model.layer3[0].relu(x)
  stats = updateStats(x.clone().view(x.shape[0], -1), stats, 'l30c2')
  x = model.layer3[0].conv2(x)
  x = model.layer3[0].bn2(x)
  #x = model.layer3[0].downsample[0](x)
  #x = model.layer3[0].downsample[1](x)
  stats = updateStats(x.clone().view(x.shape[0], -1), stats, 'l31c1')
  x = model.layer3[1].conv1(x)
  x = model.layer3[1].bn1(x)
  x = model.layer3[1].relu(x)
  stats = updateStats(x.clone().view(x.shape[0], -1), stats, 'l31c2')
  x = model.layer3[1].conv2(x)
  x = model.layer3[1].bn2(x)
  stats = updateStats(x.clone().view(x.shape[0], -1), stats, 'l40c1')
  x = model.layer4[0].conv1(x)
  x = model.layer4[0].bn1(x)
  x = model.layer4[0].relu(x)
  stats = updateStats(x.clone().view(x.shape[0], -1), stats, 'l40c2')
  x = model.layer4[0].conv2(x)
  x = model.layer4[0].bn2(x)
  #x = model.layer4[0].downsample[0](x)
  #x = model.layer4[0].downsample[1](x)
  stats = updateStats(x.clone().view(x.shape[0], -1), stats, 'l41c1')
  x = model.layer4[1].conv1(x)
  x = model.layer4[1].bn1(x)
  x = model.layer4[1].relu(x)
  stats = updateStats(x.clone().view(x.shape[0], -1), stats, 'l41c2')
  x = model.layer4[1].conv2(x)
  x = model.layer4[1].bn2(x)
  # NOTE(review): no pooling before this view; if the layer4 output is larger
  # than 1x1 spatially, each sample yields several 512-wide rows — confirm.
  x = x.view(-1, 512)

  stats = updateStats(x, stats, 'fc')

  x = model.fc(x)

  return stats

# Entry point: gather averaged activation statistics over a whole loader.
def gatherStats(model, test_loader):
    """Collect mean per-site activation min/max over ``test_loader``.

    Runs ``gatherActivationStats`` on every batch (eval mode, no gradients),
    then divides each accumulated sum by its ``'total'`` count.

    Returns:
        dict mapping each activation-site key to ``{"max": ..., "min": ...}``.
    """
    model.eval()
    accumulated = {}
    with torch.no_grad():
        for inputs, _target in test_loader:
            accumulated = gatherActivationStats(model, inputs, accumulated)

    return {
        site: {"max": entry["max"] / entry["total"], "min": entry["min"] / entry["total"]}
        for site, entry in accumulated.items()
    }

# Routines for performance testing

def test(model, device, test_loader, train_loader, batch_size, quantize=False, fbgemm=False, stats=None,
         max_batches=None):
    """Measure top-1 accuracy, model size and wall-clock time of ``model``.

    With ``quantize=True`` every batch goes through the from-scratch simulated
    quantization (``quantForward``) using the calibration ``stats``. Otherwise
    the float model is called directly.

    Args:
        model: module to evaluate (moved to ``device``, put in eval mode).
        device: device the model and every batch are moved to.
        test_loader: iterable of ``(inputs, targets)`` batches to score.
        train_loader: unused; kept for call compatibility.
        batch_size: kept for call compatibility; the sample count is taken
            from each batch so a short final batch is counted correctly.
        quantize: use ``quantForward`` instead of the float forward pass.
        fbgemm: unused; kept for call compatibility.
        stats: activation statistics from ``gatherStats``; required when
            ``quantize`` is True.
        max_batches: if given, stop after this many test batches; ``None``
            evaluates the whole loader.

    Raises:
        ValueError: if ``quantize`` is True without ``stats``, if a batch's
            output rows do not match its labels, or if ``test_loader`` yields
            no samples.
    """
    if quantize and stats is None:
        raise ValueError("quantize=True requires calibration stats")
    model.to(device)
    model.eval()

    correct = 0
    total = 0
    with torch.no_grad():
        st = time.time()
        for batch_idx, (X, y) in enumerate(test_loader):
            if max_batches is not None and batch_idx >= max_batches:
                break
            X, y = X.to(device), y.to(device)
            # Testing with quantization if quantize=True
            if quantize:
                output = quantForward(model, X, stats)
            else:
                output = model(X)
            if output.shape[0] != y.shape[0]:
                raise ValueError("model produced {} output rows for {} labels".format(
                    output.shape[0], y.shape[0]))
            correct += (output.argmax(dim=1) == y).sum().item()
            total += y.size(0)
        et = time.time()
    if total == 0:
        raise ValueError("test_loader yielded no samples")
    acc = correct / total
    print("========================================= PERFORMANCE =============================================")
    print_size_of_model(model)
    print("PyTorch optimized model test accuracy :{:.2f}% ".format(100 * acc))
    print('Elapsed time = {:0.4f} milliseconds'.format((et - st) * 1000))
    print("====================================================================================================")

In [None]:
# Quantized model performance: calibrate on a copy so `net` stays untouched.
import copy
netqq = copy.deepcopy(net)

# One-time activation-statistics gathering over the training loader; keep the
# result around for the FPGA implementation.
stats = gatherStats(netqq, train_loader)
#stats = gatherStats(netqq, val_loader)
print(stats)

In [None]:
# Quantized inference: evaluate the from-scratch quantized path with the gathered stats.
test(netqq, device, val_loader, train_loader, bs, quantize=True, stats=stats)

### What, only 1% accuracy!?
### We need quantization-aware training (QAT) to mitigate this.

In [None]:
# Further definitions needed for inference on the QAT model trained offline.
from collections import namedtuple

# Quantized tensor: integer-valued data plus its scale and zero point.
QTensor = namedtuple('QTensor', ('tensor', 'scale', 'zero_point'))

def calcScaleZeroPoint(min_val, max_val,num_bits=8):
  """Return ``(scale, zero_point)`` for asymmetric unsigned quantization.

  Maps ``[min_val, max_val]`` onto ``[0, 2**num_bits - 1]`` with
  ``real = scale * (q - zero_point)``. The range is first widened to contain
  0. Real zero is then exactly representable. When all values share a sign,
  the whole range still fits on the grid; a clamped zero point would otherwise
  clip one end.

  Args:
    min_val, max_val: calibration range (Python numbers or 0-d tensors).
    num_bits: bit width of the quantized values.

  Returns:
    ``(scale, zero_point)``: a positive float scale and an int zero point in
    ``[0, 2**num_bits - 1]``. An all-zero range yields ``(1.0, 0)`` instead of
    dividing by zero.
  """
  qmin = 0.
  qmax = 2.**num_bits - 1.

  min_val = min(float(min_val), 0.)
  max_val = max(float(max_val), 0.)
  if max_val == min_val:
    return 1.0, 0

  scale = (max_val - min_val) / (qmax - qmin)

  initial_zero_point = qmin - min_val / scale
  # Round (not truncate) to the nearest grid point and keep it in range.
  zero_point = int(round(min(max(initial_zero_point, qmin), qmax)))

  return scale, zero_point

def calcScaleZeroPointSym(min_val, max_val,num_bits=8):
  """Return ``(scale, 0)`` for symmetric signed quantization.

  The scale maps the largest magnitude in ``[min_val, max_val]`` onto
  ``2**(num_bits-1) - 1``; the zero point is always 0.

  Args:
    min_val, max_val: calibration range (Python numbers or 0-d tensors).
    num_bits: bit width including the sign bit.

  Returns:
    ``(scale, 0)`` with ``scale`` a positive float. An all-zero range yields
    ``(1.0, 0)`` instead of a zero scale that later divisions would blow up on.
  """
  qmax = 2.**(num_bits - 1) - 1.
  max_abs = max(abs(float(min_val)), abs(float(max_val)))

  scale = max_abs / qmax if max_abs > 0 else 1.0

  return scale, 0

def quantize_tensor(x, num_bits=8, min_val=None, max_val=None):
    """Asymmetrically quantize ``x`` onto the unsigned ``num_bits`` grid.

    Args:
        x: float tensor to quantize.
        num_bits: bit width of the quantized values.
        min_val, max_val: calibration range; a bound left as ``None`` is
            taken from ``x`` itself. A bound of 0 is honoured, not treated as
            missing.

    Returns:
        ``QTensor(tensor, scale, zero_point)``. ``tensor`` is uint8 for
        ``num_bits <= 8`` and int32 otherwise, so wider grids do not wrap.
    """
    if min_val is None:
        min_val = x.min()
    if max_val is None:
        max_val = x.max()

    qmin = 0.
    qmax = 2.**num_bits - 1.

    scale, zero_point = calcScaleZeroPoint(min_val, max_val, num_bits)
    q_x = (zero_point + x / scale).clamp(qmin, qmax).round()
    q_x = q_x.to(torch.uint8 if num_bits <= 8 else torch.int32)

    return QTensor(tensor=q_x, scale=scale, zero_point=zero_point)

def dequantize_tensor(q_x):
    """Recover real values from an asymmetric quantized tensor: ``scale * (q - zero_point)``."""
    centred = q_x.tensor.float() - q_x.zero_point
    return q_x.scale * centred

def quantize_tensor_sym(x, num_bits=8, min_val=None, max_val=None):
    """Symmetrically quantize ``x`` onto ``[-(2**(num_bits-1)-1), 2**(num_bits-1)-1]``.

    Args:
        x: float tensor to quantize.
        num_bits: bit width including the sign bit.
        min_val, max_val: calibration range; a bound left as ``None`` is
            taken from ``x`` itself. A bound of 0 is honoured, not treated as
            missing.

    Returns:
        ``QTensor(tensor, scale, 0)``. ``tensor`` is a float tensor holding
        integer values. An all-zero range uses scale 1.0 instead of dividing
        by zero.
    """
    if min_val is None:
        min_val = x.min()
    if max_val is None:
        max_val = x.max()

    qmax = 2.**(num_bits - 1) - 1.
    max_abs = max(abs(float(min_val)), abs(float(max_val)))
    scale = max_abs / qmax if max_abs > 0 else 1.0

    q_x = (x / scale).clamp(-qmax, qmax).round()
    return QTensor(tensor=q_x, scale=scale, zero_point=0)

def dequantize_tensor_sym(q_x):
    """Recover real values from a symmetric quantized tensor: ``scale * q``."""
    values = q_x.tensor.float()
    return q_x.scale * values

def quantizeLayer(x, layer, stat, scale_x, zp_x, vis=False, axs=None, X=None, y=None, sym=False, num_bits=8):
  """Simulate an integer conv/linear layer on an already-quantized input.

  The layer's weights are quantized to ``num_bits``, and so is its bias if it
  has one. The rescaling factor ``scale_x * scale_w / scale_next`` is folded
  into the weights, and the bias is rescaled to the output grid. The layer is
  then run so that its result lies on the output grid derived from ``stat``.
  The result is rounded and passed through ``F.relu``. The float parameters
  are restored afterwards, even if the forward pass raises.

  Args:
    x: input on the quantized grid (integer values).
    layer: conv or linear module; temporarily modified, always restored.
      Layers without a bias (``bias=False``) are supported.
    stat: dict with ``'min'``/``'max'`` used for the output grid.
    scale_x, zp_x: quantization parameters of ``x`` (``zp_x`` is not used
      when ``sym`` is True).
    vis, axs, X, y: when ``vis`` is True the integer weights are drawn on
      ``axs[X, y]`` via ``visualise`` (defined elsewhere, not shown here).
    sym: symmetric (zero point 0) instead of asymmetric quantization.
    num_bits: bit width for weights, bias and the output grid.

  Returns:
    ``(x_q, scale_next, zero_point_next)``.
  """
  # cache the float parameters so they can be restored
  W = layer.weight.data
  B = layer.bias.data if layer.bias is not None else None

  # quantise weights (and bias); activations are already quantised
  if sym:
    w = quantize_tensor_sym(W, num_bits=num_bits)
    b = quantize_tensor_sym(B, num_bits=num_bits) if B is not None else None
    scale_next, zero_point_next = calcScaleZeroPointSym(min_val=stat['min'], max_val=stat['max'], num_bits=num_bits)
  else:
    w = quantize_tensor(W, num_bits=num_bits)
    b = quantize_tensor(B, num_bits=num_bits) if B is not None else None
    scale_next, zero_point_next = calcScaleZeroPoint(min_val=stat['min'], max_val=stat['max'], num_bits=num_bits)

  try:
    layer.weight.data = w.tensor.float()
    if b is not None:
      layer.bias.data = b.tensor.float()

    if vis:
      axs[X, y].set_xlabel("Visualising weights of layer: ")
      visualise(layer.weight.data, axs[X, y])

    # Fold dequantisation and output requantisation into the parameters:
    # q_out = (s_x*s_w/s_out) * sum((w_q - zp_w)(x_q - zp_x)) + (s_b/s_out)(b_q - zp_b) + zp_out
    if sym:
      x_in = x.float()
      layer.weight.data = ((scale_x * w.scale) / scale_next) * layer.weight.data
      if b is not None:
        layer.bias.data = (b.scale / scale_next) * layer.bias.data
      out = layer(x_in)
    else:
      x_in = x.float() - zp_x
      layer.weight.data = ((scale_x * w.scale) / scale_next) * (layer.weight.data - w.zero_point)
      if b is not None:
        # dequantisation is scale * (q - zero_point): the zero point is subtracted
        layer.bias.data = (b.scale / scale_next) * (layer.bias.data - b.zero_point)
      out = layer(x_in) + zero_point_next

    # cast to the integer grid, then ReLU
    # NOTE(review): in the asymmetric case F.relu clamps at integer 0 rather
    # than at zero_point_next, and nothing clamps at qmax — confirm this
    # matches the QAT training / FPGA saturation behaviour.
    out = F.relu(torch.round(out))
  finally:
    # Reset parameters for the next forward pass
    layer.weight.data = W
    if B is not None:
      layer.bias.data = B

  return out, scale_next, zero_point_next

# Get Min and max of x tensor, and stores it
def updateStats(x, stats, key):
  """Accumulate per-sample min/max of ``x`` (plus moving averages) into ``stats[key]``.

  Args:
    x: 2-D tensor ``(batch, features)``; the callers flatten activations with
      ``view(x.shape[0], -1)``.
    stats: dict being accumulated; mutated in place and returned.
    key: name of the activation site.

  Returns:
    ``stats``. ``stats[key]`` holds the following values (Python floats):

    * ``'max'``/``'min'``: running sums of the per-sample extrema.
    * ``'total'``: number of samples summed.
    * ``'max_val'``/``'min_val'``: the resulting means.
    * ``'ema_max'``/``'ema_min'``: moving averages of the per-batch mean
      extrema, updated with weight ``2 / (n + 1)``, where ``n`` is the number
      of batches seen (kept in ``'ema_updates'``). The first batch initialises
      them.
  """
  max_val, _ = torch.max(x, dim=1)
  min_val, _ = torch.min(x, dim=1)
  batch_max = max_val.mean().item()
  batch_min = min_val.mean().item()

  entry = stats.setdefault(key, {'max': 0.0, 'min': 0.0, 'total': 0})
  entry['max'] += max_val.sum().item()
  entry['min'] += min_val.sum().item()
  # count samples (not calls) so sum / total is a per-sample mean for any batch size
  entry['total'] = entry.get('total', 0) + x.shape[0]

  # moving-average weight 2 / (n + 1); always in (0, 1]
  updates = entry.get('ema_updates', 0) + 1
  entry['ema_updates'] = updates
  weighting = 2.0 / (updates + 1)

  if 'ema_min' in entry:
    entry['ema_min'] = weighting * batch_min + (1 - weighting) * entry['ema_min']
  else:
    entry['ema_min'] = batch_min

  if 'ema_max' in entry:
    entry['ema_max'] = weighting * batch_max + (1 - weighting) * entry['ema_max']
  else:
    entry['ema_max'] = batch_max

  entry['min_val'] = entry['min'] / entry['total']
  entry['max_val'] = entry['max'] / entry['total']

  return stats

# Reworked Forward Pass to access activation Stats through updateStats function
def gatherActivationStats(model, x, stats):
  """Float forward pass that records activation ranges via ``updateStats``.

  Each ``updateStats`` call records the tensor about to enter the layer named
  by its key: ``'conv1'`` is the network input, ``'l10c1'`` the input of
  ``model.layer1[0].conv1``, and so on up to ``'fc'``, the input of ``model.fc``.

  Args:
    model: the float ResNet whose sub-modules are called one by one.
    x: input batch; flattened per sample before each ``updateStats`` call.
    stats: dict accumulated across calls.

  Returns:
    The updated ``stats`` dict (the ``model.fc`` output is discarded).

  NOTE(review): this pass never calls model.maxpool (present in the printed
  model), keeps the downsample calls commented out and adds no residual
  connections. If resnet_18's BasicBlock adds the identity (standard ResNet),
  these statistics come from a different network than model.forward —
  confirm in resnet_18.py.
  """
  stats = updateStats(x.clone().view(x.shape[0], -1), stats, 'conv1')
  x = model.conv1(x)
  x = model.bn1(x)
  x = model.relu(x)

  stats = updateStats(x.clone().view(x.shape[0], -1), stats, 'l10c1')
  x = model.layer1[0].conv1(x)
  x = model.layer1[0].bn1(x)
  x = model.layer1[0].relu(x)
  stats = updateStats(x.clone().view(x.shape[0], -1), stats, 'l10c2')
  x = model.layer1[0].conv2(x)
  x = model.layer1[0].bn2(x)
  stats = updateStats(x.clone().view(x.shape[0], -1), stats, 'l11c1')
  x = model.layer1[1].conv1(x)
  x = model.layer1[1].bn1(x)
  x = model.layer1[1].relu(x)
  stats = updateStats(x.clone().view(x.shape[0], -1), stats, 'l11c2')
  x = model.layer1[1].conv2(x)
  x = model.layer1[1].bn2(x)
  stats = updateStats(x.clone().view(x.shape[0], -1), stats, 'l20c1')
  x = model.layer2[0].conv1(x)
  x = model.layer2[0].bn1(x)
  x = model.layer2[0].relu(x)
  stats = updateStats(x.clone().view(x.shape[0], -1), stats, 'l20c2')
  x = model.layer2[0].conv2(x)
  x = model.layer2[0].bn2(x)
  #x = model.layer2[0].downsample[0](x)
  #x = model.layer2[0].downsample[1](x)
  stats = updateStats(x.clone().view(x.shape[0], -1), stats, 'l21c1')
  x = model.layer2[1].conv1(x)
  x = model.layer2[1].bn1(x)
  x = model.layer2[1].relu(x)
  stats = updateStats(x.clone().view(x.shape[0], -1), stats, 'l21c2')
  x = model.layer2[1].conv2(x)
  x = model.layer2[1].bn2(x)
  stats = updateStats(x.clone().view(x.shape[0], -1), stats, 'l30c1')
  x = model.layer3[0].conv1(x)
  x = model.layer3[0].bn1(x)
  x = model.layer3[0].relu(x)
  stats = updateStats(x.clone().view(x.shape[0], -1), stats, 'l30c2')
  x = model.layer3[0].conv2(x)
  x = model.layer3[0].bn2(x)
  #x = model.layer3[0].downsample[0](x)
  #x = model.layer3[0].downsample[1](x)
  stats = updateStats(x.clone().view(x.shape[0], -1), stats, 'l31c1')
  x = model.layer3[1].conv1(x)
  x = model.layer3[1].bn1(x)
  x = model.layer3[1].relu(x)
  stats = updateStats(x.clone().view(x.shape[0], -1), stats, 'l31c2')
  x = model.layer3[1].conv2(x)
  x = model.layer3[1].bn2(x)
  stats = updateStats(x.clone().view(x.shape[0], -1), stats, 'l40c1')
  x = model.layer4[0].conv1(x)
  x = model.layer4[0].bn1(x)
  x = model.layer4[0].relu(x)
  stats = updateStats(x.clone().view(x.shape[0], -1), stats, 'l40c2')
  x = model.layer4[0].conv2(x)
  x = model.layer4[0].bn2(x)
  #x = model.layer4[0].downsample[0](x)
  #x = model.layer4[0].downsample[1](x)
  stats = updateStats(x.clone().view(x.shape[0], -1), stats, 'l41c1')
  x = model.layer4[1].conv1(x)
  x = model.layer4[1].bn1(x)
  x = model.layer4[1].relu(x)
  stats = updateStats(x.clone().view(x.shape[0], -1), stats, 'l41c2')
  x = model.layer4[1].conv2(x)
  x = model.layer4[1].bn2(x)
  # NOTE(review): no pooling before this view; if the layer4 output is larger
  # than 1x1 spatially, each sample yields several 512-wide rows — confirm.
  x = x.view(-1, 512)

  stats = updateStats(x, stats, 'fc')

  x = model.fc(x)


  return stats

# Entry point: gather averaged activation statistics over a whole loader.
def gatherStats(model, test_loader, device='cpu'):
    """Collect averaged activation statistics over ``test_loader``.

    Runs ``gatherActivationStats`` on every batch (eval mode, no gradients).
    Each accumulated sum is divided by its ``'total'`` count, and the moving
    averages are passed through.

    Args:
        model: float model to calibrate (not moved; it is expected to live on
            ``device`` already).
        test_loader: iterable of ``(inputs, targets)`` batches; targets are
            ignored.
        device: device each input batch is moved to (default ``'cpu'``, the
            previously hard-coded value).

    Returns:
        dict mapping each activation-site key to
        ``{"max", "min", "ema_min", "ema_max"}``.
    """
    model.eval()
    stats = {}
    with torch.no_grad():
        for data, _ in test_loader:
            stats = gatherActivationStats(model, data.to(device), stats)

    final_stats = {}
    for key, value in stats.items():
        final_stats[key] = {
            "max": value["max"] / value["total"],
            "min": value["min"] / value["total"],
            "ema_min": value["ema_min"],
            "ema_max": value["ema_max"],
        }
    return final_stats

def quantForward(model, x, stats, vis=False, axs=None, sym=False, num_bits=8):
  """Simulated quantised forward pass of the notebook's ResNet-18 variant.

  Steps:
  1. The network input is quantised with the range stored under
     ``stats['conv1']``.  ``quantize_tensor_sym`` is used when ``sym`` is true,
     otherwise ``quantize_tensor``.
  2. Every conv is evaluated through ``quantizeLayer(x, conv, stat, scale,
     zero_point)``.  The call receives the incoming scale/zero-point and the
     stats entry recorded *before the next conv* (the key convention of
     ``gatherActivationStats``).  It returns ``(output, scale, zero_point)``
     for the next step.
  3. BatchNorm and ReLU run between convs exactly as in the float model.
  4. The layer4 output is reshaped to ``(-1, 512)``, dequantised, and passed
     through the float ``model.fc``.

  Args:
    model: the ResNet-18 variant (stem ``conv1/bn1/relu`` and ``layer1..4``
      with two blocks each).
    x: float input batch.
    stats: dict produced by ``gatherStats``.
    vis, axs: unused; kept for call-site compatibility.
    sym: use symmetric quantisation for the network input.
    num_bits: bit width for the input quantiser.  NOTE(review):
      ``quantizeLayer`` is called without ``num_bits``, so inner layers use
      its own default -- confirm that is intended.

  Returns:
    Logits from ``model.fc``.
  """
  # (block, stats key for its conv1 output, stats key for its conv2 output),
  # in execution order.  Each key names the activation consumed next.
  blocks = (
    (model.layer1[0], 'l10c2', 'l11c1'),
    (model.layer1[1], 'l11c2', 'l20c1'),
    (model.layer2[0], 'l20c2', 'l21c1'),
    (model.layer2[1], 'l21c2', 'l30c1'),
    (model.layer3[0], 'l30c2', 'l31c1'),
    (model.layer3[1], 'l31c2', 'l40c1'),
    (model.layer4[0], 'l40c2', 'l41c1'),
    (model.layer4[1], 'l41c2', 'fc'),
  )

  # Quantise before inputting into the stem conv.
  quantize_input = quantize_tensor_sym if sym else quantize_tensor
  qx = quantize_input(x, min_val=stats['conv1']['min'], max_val=stats['conv1']['max'], num_bits=num_bits)

  # Stem: conv1 -> bn1 -> relu.
  x, scale_next, zero_point_next = quantizeLayer(qx.tensor, model.conv1, stats['l10c1'], qx.scale, qx.zero_point)
  x = model.bn1(x)
  x = model.relu(x)

  # Residual blocks: conv1 -> bn1 -> relu -> conv2 -> bn2.
  # NOTE(review): the shortcut/downsample path is not applied (it is also
  # commented out in the original stats pass) -- confirm against resnet_18.
  for block, key_c1, key_c2 in blocks:
    x, scale_next, zero_point_next = quantizeLayer(x, block.conv1, stats[key_c1], scale_next, zero_point_next)
    x = block.bn1(x)
    x = block.relu(x)
    x, scale_next, zero_point_next = quantizeLayer(x, block.conv2, stats[key_c2], scale_next, zero_point_next)
    x = block.bn2(x)

  x = x.view(-1, 512)

  # Back to float for the final (unquantised) classifier.
  x = dequantize_tensor(QTensor(tensor=x, scale=scale_next, zero_point=zero_point_next))
  return model.fc(x)

import torch

class FakeQuantOp(torch.autograd.Function):
    """Quantise-then-dequantise ("fake quantisation") with a straight-through gradient.

    Forward pushes ``x`` through the notebook's ``quantize_tensor`` (with the
    given ``num_bits`` and optional ``min_val``/``max_val`` range) and then
    ``dequantize_tensor``, so the result is a float tensor carrying the
    quantisation error.

    Backward returns ``grad_output`` unchanged for ``x`` (straight-through
    estimator).  The three configuration arguments receive no gradient.
    """

    @staticmethod
    def forward(ctx, x, num_bits=8, min_val=None, max_val=None):
        quantized = quantize_tensor(x, num_bits=num_bits, min_val=min_val, max_val=max_val)
        return dequantize_tensor(quantized)

    @staticmethod
    def backward(ctx, grad_output):
        # One gradient slot per forward input: x, num_bits, min_val, max_val.
        no_grad_for_config = (None, None, None)
        return (grad_output,) + no_grad_for_config

def quantAwareTrainingForward(model, x, stats, vis=False, axs=None, sym=False, num_bits=8, act_quant=False):
  """Forward pass with fake-quantised conv weights, for quantisation-aware training.

  Each of the 17 convs is processed in order: the stem conv, then conv1 and
  conv2 of the eight residual blocks.  For each conv this function:

  1. remembers the current float ``weight.data``;
  2. overwrites ``weight.data`` with ``FakeQuantOp.apply(weight.data, num_bits)``;
  3. runs the conv and, under ``no_grad``, records the flattened conv output in
     ``stats`` under that conv's own key (``'conv1'``, ``'l10c1'``, ``'l10c2'``,
     ..., ``'l41c2'``);
  4. if ``act_quant``, fake-quantises that output using the key's
     ``ema_min`` / ``ema_max``;
  5. applies the following BatchNorm, plus ReLU after the stem conv and after
     each block's conv1.

  Finally the output is reshaped to ``(-1, 512)`` and passed through
  ``model.fc``, and the logits are recorded under ``'fc'``.

  The fake-quantised weights are left inside the model.  The caller must put
  back the returned float weights (as ``tinytrain`` / ``test_qat`` do).

  NOTE(review): here a key names the *output* of the conv of the same name.
  ``gatherActivationStats`` records e.g. ``'l30c1'`` *before*
  ``layer3[0].conv1``, i.e. its input.  Mixing the two stat sets may be
  unintended -- confirm.
  NOTE(review): no residual add, maxpool or avgpool is applied, and
  ``view(-1, 512)`` assumes 512 elements per sample after layer4 -- confirm
  against the resnet_18 module.

  Args:
    model: the ResNet-18 variant; conv weights are temporarily replaced.
    x: input batch.
    stats: running activation-stats dict, updated via ``updateStats``.
    vis, axs, sym: unused; kept for call-site compatibility.
    num_bits: bit width for weight (and activation) fake quantisation.
    act_quant: also fake-quantise each conv output.

  Returns:
    19-tuple ``(logits, w1, ..., w17, stats)``, where ``w1..w17`` are the
    original float weight tensors in the processing order above.
  """
  # (block, key for its conv1 output, key for its conv2 output)
  blocks = (
    (model.layer1[0], 'l10c1', 'l10c2'),
    (model.layer1[1], 'l11c1', 'l11c2'),
    (model.layer2[0], 'l20c1', 'l20c2'),
    (model.layer2[1], 'l21c1', 'l21c2'),
    (model.layer3[0], 'l30c1', 'l30c2'),
    (model.layer3[1], 'l31c1', 'l31c2'),
    (model.layer4[0], 'l40c1', 'l40c2'),
    (model.layer4[1], 'l41c1', 'l41c2'),
  )

  # (conv, batchnorm, relu-or-None, stats key) in execution order.
  steps = [(model.conv1, model.bn1, model.relu, 'conv1')]
  for block, key_c1, key_c2 in blocks:
    steps.append((block.conv1, block.bn1, block.relu, key_c1))
    steps.append((block.conv2, block.bn2, None, key_c2))

  float_weights = []
  for conv, bn, relu, key in steps:
    float_weights.append(conv.weight.data)
    conv.weight.data = FakeQuantOp.apply(conv.weight.data, num_bits)
    x = conv(x)
    with torch.no_grad():
      stats = updateStats(x.clone().view(x.shape[0], -1), stats, key)
    if act_quant:
      x = FakeQuantOp.apply(x, num_bits, stats[key]['ema_min'], stats[key]['ema_max'])
    x = bn(x)
    if relu is not None:
      x = relu(x)

  x = x.view(-1, 512)
  x = model.fc(x)

  with torch.no_grad():
    stats = updateStats(x.clone().view(x.shape[0], -1), stats, 'fc')

  return (x, *float_weights, stats)
# Training
def tinytrain(epoch, trainloader, optimizer, criterion, model, device, stats, act_quant=False, num_bits=8):
    """Run one epoch of quantisation-aware training.

    Each batch goes through ``quantAwareTrainingForward``, which leaves
    fake-quantised weights in the 17 convs.  The returned float weights are put
    back *before* ``loss.backward()`` / ``optimizer.step()``, so the step
    updates the full-precision weights.  Running loss and accuracy are reported
    through ``progress_bar``.

    Args:
        epoch: epoch number, only used in the header print.
        trainloader: iterable of ``(inputs, targets)`` batches.
        optimizer: optimiser over ``model``'s parameters.
        criterion: loss function ``criterion(outputs, targets)``.
        model: the ResNet-18 variant being trained.
        device: device batches are moved to.
        stats: activation-stats dict threaded through the forward pass.
        act_quant: also fake-quantise activations.
        num_bits: fake-quantisation bit width.
    """
    print('\nEpoch: %d' % epoch)
    model.train()

    # Same order as the float weights returned by quantAwareTrainingForward.
    qat_convs = (
        model.conv1,
        model.layer1[0].conv1, model.layer1[0].conv2,
        model.layer1[1].conv1, model.layer1[1].conv2,
        model.layer2[0].conv1, model.layer2[0].conv2,
        model.layer2[1].conv1, model.layer2[1].conv2,
        model.layer3[0].conv1, model.layer3[0].conv2,
        model.layer3[1].conv1, model.layer3[1].conv2,
        model.layer4[0].conv1, model.layer4[0].conv2,
        model.layer4[1].conv1, model.layer4[1].conv2,
    )

    running_loss = 0
    n_correct = 0
    n_seen = 0
    n_batches = len(trainloader)
    for step, (inputs, targets) in enumerate(trainloader):
        inputs = inputs.to(device)
        targets = targets.to(device)
        optimizer.zero_grad()

        outputs, *float_weights, stats = quantAwareTrainingForward(
            model, inputs, stats, num_bits=num_bits, act_quant=act_quant)
        for conv, weight in zip(qat_convs, float_weights):
            conv.weight.data = weight

        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        predicted = outputs.max(1)[1]
        n_seen += targets.size(0)
        n_correct += predicted.eq(targets).sum().item()

        message = 'Loss: %.3f | Acc: %.3f%% (%d/%d)' % (
            running_loss / (step + 1), 100. * n_correct / n_seen, n_correct, n_seen)
        progress_bar(step, n_batches, message)


def test_qat(epoch, testloader, criterion, model, device, stats, act_quant, num_bits=8):
    """Evaluate ``model`` with fake-quantised weights (and optionally activations).

    Each batch runs through ``quantAwareTrainingForward`` under ``no_grad``.
    The float weights it returns are restored into the 17 convs straight
    afterwards.  Running loss and accuracy are reported through ``progress_bar``.

    NOTE: ``quantAwareTrainingForward`` calls ``updateStats``, so evaluation
    also keeps updating the activation stats in ``stats``.

    Args:
        epoch: unused; kept for call-site compatibility.
        testloader: iterable of ``(inputs, targets)`` batches.
        criterion: loss function ``criterion(outputs, targets)``.
        model: the ResNet-18 variant under test.
        device: device batches are moved to.
        stats: activation-stats dict threaded through the forward pass.
        act_quant: also fake-quantise activations.
        num_bits: fake-quantisation bit width.

    Returns:
        Top-1 accuracy in percent, or ``0.0`` if ``testloader`` yields nothing.
        (Previously ``None`` was returned.)
    """
    model.eval()

    # Same order as the float weights returned by quantAwareTrainingForward.
    qat_convs = (
        model.conv1,
        model.layer1[0].conv1, model.layer1[0].conv2,
        model.layer1[1].conv1, model.layer1[1].conv2,
        model.layer2[0].conv1, model.layer2[0].conv2,
        model.layer2[1].conv1, model.layer2[1].conv2,
        model.layer3[0].conv1, model.layer3[0].conv2,
        model.layer3[1].conv1, model.layer3[1].conv2,
        model.layer4[0].conv1, model.layer4[0].conv2,
        model.layer4[1].conv1, model.layer4[1].conv2,
    )

    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs, *float_weights, stats = quantAwareTrainingForward(
                model, inputs, stats, num_bits=num_bits, act_quant=act_quant)
            # The forward pass leaves fake-quantised weights in the model;
            # put the full-precision ones back before the next batch.
            for conv, weight in zip(qat_convs, float_weights):
                conv.weight.data = weight
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                         % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))

    return 100. * correct / total if total else 0.0

In [None]:
# Load the QAT-trained checkpoint on the CPU and calibrate its activation
# ranges by running gatherStats over the training loader.  The accuracy
# test itself happens in the next cell.
pretrained_modelqat = "./imgnet_qat.pt"
netqat = resnet18()
sdqat = torch.load(pretrained_modelqat, map_location='cpu')
netqat.load_state_dict(sdqat['net'])

stats = gatherStats(netqat, trainloader)
print(stats)

In [None]:
# Evaluate the QAT model with 8-bit weight AND activation fake quantisation.
act_quant = True
epoch = 1
criterion = nn.CrossEntropyLoss()
test_qat(epoch, testloader, criterion, netqat, device, stats, act_quant=act_quant, num_bits=8)

### Now we get decent accuracy with both weight and activation quantization. This is the model we will take to the FPGA.


## Section 3: DietCNN Inference (Multiplication-Free)
### Its main benefit shows up in the FPGA implementation.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torchvision import transforms, datasets
from tqdm.auto import tqdm

# For training
import torch.optim as optim
from torch.autograd import Variable
import torchvision
import faiss
import sys
sys.path.insert(1, '../core')
from lut_utils_tiny import *
import warnings
warnings.filterwarnings('ignore')
from patchlib import *
import multiprocessing
from joblib import Parallel, delayed
from evalutils_resnet_imgnet import *
PARALLEL = 30 

In [None]:
# Test method for the DietCNN Tiny Image Net ResNet Network. 

HOWDY = 20000000 

# Test accuracy of symbolic inference
def test_fullsym_acc(model, data_iter, bss=1):
    correct = 0 
    total = 0 
    counter = 0
    model.eval()
    for data in data_iter:
        X, y = data
        if counter > HOWDY:
            break
        output = model.forward(X)
        for idx, i in enumerate(output):
            if torch.argmax(i) == y[idx]:
            #if True:
                correct += 1
        
        counter +=bss 
        total += bss
        if(counter > 0 and counter % bss == 0):
            print("Full symbolic model test accuracy DietCNN :{}% ".format(100*round(correct/total, 4)))
    return round(correct/total, 4)

In [None]:
# Build the DietCNN symbolic (multiplication-free) model and measure its accuracy.
#
# DietCNN hyperparameters (three main ones):
# 1. Symbols for the image and all activations (the faiss codebook below).
# 2 & 3. Symbol dictionaries for the CONV and FC layer filters.

# Activation-symbol codebook; its first n_clusters vectors are read back below.
# NOTE(review): the file name mentions "alexnet" -- confirm this is the intended
# codebook for the ImageNet ResNet model.
index = faiss.read_index("./kmeans_alexnet_c1_k1_s1_512_repeat2_v10.index")

n_clusters=512

# Single-pixel patches for now.
# conv_patch_size / all_patch_size are not referenced in this cell.
conv_patch_size = (1, 1)
patch_size = (1, 1)
all_patch_size = (1, 1)

patch_stride = 1
# Hyperparameters 2 & 3: sizes of the CONV and FC filter dictionaries
# (not referenced in this cell; the prebuilt dictionaries are loaded below).
n_cluster_conv_filters = 256
n_cluster_fc_filters = 128
conv_stride = 1

# Reverse dictionary: symbol -> patch (the first n_clusters vectors stored in the index).
centroid_lut = index.reconstruct_n(0, n_clusters)

import pickle
# Load the prebuilt CONV and FC filter dictionaries and the int16 lookup tables.
# NOTE: pickle.load can execute arbitrary code -- only load these locally
# generated artefacts, never files from an untrusted source.
with open('imgnet_conv_flt.index', "rb") as f:
    filter_index_conv = pickle.load(f)
with open('imgnet_fc_flt.index', "rb") as f:
    filter_index_fc = pickle.load(f)
fc_lut = np.genfromtxt('./imgnet_fc_lut.txt', delimiter=',',dtype=np.int16)
conv_lut = np.genfromtxt('./imgnet_conv_lut.txt', delimiter=',',dtype=np.int16)
add_lut = np.genfromtxt('./imgnet_add_lut.txt', delimiter=',',dtype=np.int16)
relu_lut = np.genfromtxt('./imgnet_relu_lut.txt', delimiter=',',dtype=np.int16)

# Create the symbolic model.
# All these steps are needed in the desktop implementation to stay in sync with
# PyTorch inference; the FPGA implementation of DietCNN models is much simpler.

print(" Symbolic model loading started...")
# process_time() measures CPU time of this process (not wall-clock) for loading.
t = time.process_time()
# The builder is still named vgg_sym for historical reasons; it builds this ResNet.
# NOTE(review): it is given `net` / `sd`, which come from an earlier cell not
# shown here -- not the QAT `netqat` / `sdqat` loaded above. Confirm intended.
netsym = vgg_sym(net,sd, filter_index_conv, filter_index_fc, conv_lut, fc_lut, add_lut, 
                  relu_lut, n_clusters, index, centroid_lut, patch_size, patch_stride)

elapsed_time3 = time.process_time() - t
print("Symbolic model loading completed in:",elapsed_time3)
netsym.eval()
# Wall-clock timing of the full symbolic evaluation over val_loader
# (val_loader is defined in an earlier cell).
start_t = time.time()  
acc = test_fullsym_acc(netsym, val_loader)
end = time.time()
print("elapsed time for symbolic inference:", end - start_t) 