In [38]:
from __future__ import print_function, division
from vgg_sym import *
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os
from torchvision import transforms
from torch.utils.data import DataLoader, TensorDataset, Dataset
from torchvision import models, datasets

# For training
from torch.autograd import Variable
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import time
import os
import copy
cudnn.benchmark = True
best_acc = 0  # best test accuracy
start_epoch = 0  # start from epoch 0 or last checkpoint epoch
from collections import namedtuple
import torch
import torch.nn as nn

In [39]:
QTensor = namedtuple('QTensor', ['tensor', 'scale', 'zero_point'])

def calcScaleZeroPoint(min_val, max_val,num_bits=8):
  # Calc Scale and zero point of next 
  qmin = 0.
  qmax = 2.**num_bits - 1.

  scale = (max_val - min_val) / (qmax - qmin)

  initial_zero_point = qmin - min_val / scale
  
  zero_point = 0
  if initial_zero_point < qmin:
      zero_point = qmin
  elif initial_zero_point > qmax:
      zero_point = qmax
  else:
      zero_point = initial_zero_point

  zero_point = int(zero_point)

  return scale, zero_point

def calcScaleZeroPointSym(min_val, max_val,num_bits=8):
  
  # Calc Scale 
  max_val = max(abs(min_val), abs(max_val))
  qmin = 0.
  qmax = 2.**(num_bits-1) - 1.

  scale = max_val / qmax

  return scale, 0

def quantize_tensor(x, num_bits=8, min_val=None, max_val=None):
    
    if not min_val and not max_val: 
      min_val, max_val = x.min(), x.max()

    qmin = 0.
    qmax = 2.**num_bits - 1.

    scale, zero_point = calcScaleZeroPoint(min_val, max_val, num_bits)
    q_x = zero_point + x / scale
    q_x.clamp_(qmin, qmax).round_()
    q_x = q_x.round().byte()
    
    return QTensor(tensor=q_x, scale=scale, zero_point=zero_point)

def dequantize_tensor(q_x):
    return q_x.scale * (q_x.tensor.float() - q_x.zero_point)

def quantize_tensor_sym(x, num_bits=8, min_val=None, max_val=None):
    
    if not min_val and not max_val: 
      min_val, max_val = x.min(), x.max()

    max_val = max(abs(min_val), abs(max_val))
    qmin = 0.
    qmax = 2.**(num_bits-1) - 1.

    scale = max_val / qmax   

    q_x = x/scale

    q_x.clamp_(-qmax, qmax).round_()
    q_x = q_x.round()
    return QTensor(tensor=q_x, scale=scale, zero_point=0)

def dequantize_tensor_sym(q_x):
    return q_x.scale * (q_x.tensor.float())

def quantizeLayer(x, layer, stat, scale_x, zp_x, vis=False, axs=None, X=None, y=None, sym=False, num_bits=8):
  # for both conv and linear layers

  # cache old values
  W = layer.weight.data
  B = layer.bias.data

  # WEIGHTS SIMULATED QUANTISED

  # quantise weights, activations are already quantised
  if sym:
    w = quantize_tensor_sym(layer.weight.data,num_bits=num_bits) 
    b = quantize_tensor_sym(layer.bias.data,num_bits=num_bits)
  else:
    w = quantize_tensor(layer.weight.data, num_bits=num_bits) 
    b = quantize_tensor(layer.bias.data, num_bits=num_bits)

  layer.weight.data = w.tensor.float()
  layer.bias.data = b.tensor.float()

  ## END WEIGHTS QUANTISED SIMULATION


  if vis:
    axs[X,y].set_xlabel("Visualising weights of layer: ")
    visualise(layer.weight.data, axs[X,y])

  # QUANTISED OP, USES SCALE AND ZERO POINT TO DO LAYER FORWARD PASS. (How does backprop change here ?)
  # This is Quantisation Arithmetic
  scale_w = w.scale
  zp_w = w.zero_point
  scale_b = b.scale
  zp_b = b.zero_point
  
  if sym:
    scale_next, zero_point_next = calcScaleZeroPointSym(min_val=stat['min'], max_val=stat['max'])
  else:
    scale_next, zero_point_next = calcScaleZeroPoint(min_val=stat['min'], max_val=stat['max'])

  # Preparing input by saturating range to num_bits range.
  if sym:
    X = x.float()
    layer.weight.data = ((scale_x * scale_w) / scale_next)*(layer.weight.data)
    layer.bias.data = (scale_b/scale_next)*(layer.bias.data)
  else:
    X = x.float() - zp_x
    layer.weight.data = ((scale_x * scale_w) / scale_next)*(layer.weight.data - zp_w)
    layer.bias.data = (scale_b/scale_next)*(layer.bias.data + zp_b)

  # All int computation
  if sym:  
    x = (layer(X)) 
  else:
    x = (layer(X)) + zero_point_next 
  
  # cast to int
  x.round_()

  # Perform relu too
  x = F.relu(x)

  # Reset weights for next forward pass
  layer.weight.data = W
  layer.bias.data = B
  
  return x, scale_next, zero_point_next

# Get Min and max of x tensor, and stores it
def updateStats(x, stats, key):
  max_val, _ = torch.max(x, dim=1)
  min_val, _ = torch.min(x, dim=1)

  # add ema calculation

  if key not in stats:
    stats[key] = {'max': max_val.sum(), 'min': min_val.sum(), 'total': 1}
  else:
    stats[key]['max'] += max_val.sum().item()
    stats[key]['min'] += min_val.sum().item()
    if 'total' in stats[key]:
        stats[key]['total'] += 1
    else:
        stats[key]['total'] = 1
  
  weighting = 2.0 / (stats[key]['total']) + 1

  if 'ema_min' in stats[key]:
    stats[key]['ema_min'] = weighting*(min_val.mean().item()) + (1- weighting) * stats[key]['ema_min']
  else:
    stats[key]['ema_min'] = weighting*(min_val.mean().item())

  if 'ema_max' in stats[key]:
    stats[key]['ema_max'] = weighting*(max_val.mean().item()) + (1- weighting) * stats[key]['ema_max']
  else: 
    stats[key]['ema_max'] = weighting*(max_val.mean().item())

  stats[key]['min_val'] = stats[key]['min']/ stats[key]['total']
  stats[key]['max_val'] = stats[key]['max']/ stats[key]['total']
  
  return stats

# Reworked Forward Pass to access activation Stats through updateStats function
def gatherActivationStats(model, x, stats):
  stats = updateStats(x.clone().view(x.shape[0], -1), stats, 'conv1')
  x = model.features[1](model.features[0](x))
  stats = updateStats(x.clone().view(x.shape[0], -1), stats, 'conv2')
  x =  model.features[3](model.features[2](x))
  stats = updateStats(x.clone().view(x.shape[0], -1), stats, 'conv3')
  x = model.features[5](model.features[4](x))
  stats = updateStats(x.clone().view(x.shape[0], -1), stats, 'conv4')
  x = model.features[7](model.features[6](x))
  stats = updateStats(x.clone().view(x.shape[0], -1), stats, 'conv5')
  x = model.features[9](model.features[8](x))
  stats = updateStats(x.clone().view(x.shape[0], -1), stats, 'conv6')
  x = model.features[11](model.features[10](x))
  stats = updateStats(x.clone().view(x.shape[0], -1), stats, 'conv7')
  x = model.features[13](model.features[12](x))
  stats = updateStats(x.clone().view(x.shape[0], -1), stats, 'conv8')
  x = model.features[15](model.features[14](x))

  #x = x.view(x.size(0), -1)  
  x = x.view(-1, 512) 
  
  stats = updateStats(x, stats, 'fc')

  x = model.classifier(x)

  return stats

# Entry function to get stats of all functions.
def gatherStats(model, test_loader):
    device = 'cpu'
    
    model.eval()
    test_loss = 0
    correct = 0
    stats = {}
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            stats = gatherActivationStats(model, data, stats)
    
    final_stats = {}
    for key, value in stats.items():
      final_stats[key] = { "max" : value["max"] / value["total"], "min" : value["min"] / value["total"], "ema_min": value["ema_min"], "ema_max": value["ema_max"] }
    return final_stats

def quantForward(model, x, stats, vis=False, axs=None, sym=False, num_bits=8):
  X = 0
  y = 0
  # Quantise before inputting into incoming layers
  if sym:
    x = quantize_tensor_sym(x, min_val=stats['conv1']['min'], max_val=stats['conv1']['max'], num_bits=num_bits)
  else:
    x = quantize_tensor(x, min_val=stats['conv1']['min'], max_val=stats['conv1']['max'], num_bits=num_bits)

    # Quantise before inputting into incoming layers
  if sym:
    x = quantize_tensor_sym(x, min_val=stats['conv1']['min'], max_val=stats['conv1']['max'], num_bits=nb)
  else:
    x = quantize_tensor(x, min_val=stats['conv1']['min'], max_val=stats['conv1']['max'], num_bits=nb)

  x, scale_next, zero_point_next = quantizeLayer(x.tensor, model.features[0], stats['conv2'], x.scale, x.zero_point)
  #x = model.features[1](x)
  x, scale_next, zero_point_next = quantizeLayer(x, model.features[2], stats['conv3'], scale_next, zero_point_next)
  #x = model.features[3](x)
  x, scale_next, zero_point_next = quantizeLayer(x, model.features[4], stats['conv4'], scale_next, zero_point_next)
  #x = model.features[5](x)
  x, scale_next, zero_point_next = quantizeLayer(x, model.features[6], stats['conv5'], scale_next, zero_point_next)
  #x = model.features[7](x)
  x, scale_next, zero_point_next = quantizeLayer(x, model.features[8], stats['conv6'], scale_next, zero_point_next)
  #x = model.features[9](x)
  x, scale_next, zero_point_next = quantizeLayer(x, model.features[10], stats['conv7'], scale_next, zero_point_next)
  #x = model.features[11](x)
  x, scale_next, zero_point_next = quantizeLayer(x, model.features[12], stats['conv8'], scale_next, zero_point_next)
  #x = model.features[13](x)
  x, scale_next, zero_point_next = quantizeLayer(x, model.features[14], stats['fc'], scale_next, zero_point_next)
  #x = model.features[15](x)
    
  
  #x = x.view(x.size(0), -1)  
  x = x.view(-1, 512)   
  
  
  # Back to dequant for final layer
  x = dequantize_tensor(QTensor(tensor=x, scale=scale_next, zero_point=zero_point_next))
   
  x = model.classifier(x)

  return x

import torch

class FakeQuantOp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, num_bits=8, min_val=None, max_val=None):
        x = quantize_tensor(x,num_bits=num_bits, min_val=min_val, max_val=max_val)
        x = dequantize_tensor(x)
        return x

    @staticmethod
    def backward(ctx, grad_output):
        # straight through estimator
        return grad_output, None, None, None

def quantAwareTrainingForward(model, x, stats, vis=False, axs=None, sym=False, num_bits=8, act_quant=False):

  #print(x.shape)
  #print(model.features[0].weight.data.shape) 
  #x = model.features[0](x)
  #x = model.features[1](x)
  #stats = updateStats(x.clone().view(x.shape[0], -1), stats, 'conv2')
  #x =  model.features[3](model.features[2](x))
  #stats = updateStats(x.clone().view(x.shape[0], -1), stats, 'conv3')
  #x = model.features[5](model.features[4](x))
  #stats = updateStats(x.clone().view(x.shape[0], -1), stats, 'conv4')
  #x = model.features[7](model.features[6](x))
  #stats = updateStats(x.clone().view(x.shape[0], -1), stats, 'conv5')
  #x = model.features[9](model.features[8](x))
  #stats = updateStats(x.clone().view(x.shape[0], -1), stats, 'conv6')
  #x = model.features[11](model.features[10](x))
  #stats = updateStats(x.clone().view(x.shape[0], -1), stats, 'conv7')
  #x = model.features[13](model.features[12](x))
  #stats = updateStats(x.clone().view(x.shape[0], -1), stats, 'conv8')
  #x = model.features[15](model.features[14](x))
  ##x = x.view(x.size(0), -1)  
  #x = x.view(-1, 512) 
  #stats = updateStats(x, stats, 'fc')
  #x = model.classifier(x)

  conv1weight = model.features[0].weight.data
  model.features[0].weight.data = FakeQuantOp.apply(model.features[0].weight.data, num_bits)
  x = model.features[1](model.features[0](x))

  with torch.no_grad():
    stats = updateStats(x.clone().view(x.shape[0], -1), stats, 'conv1')

  if act_quant:
    x = FakeQuantOp.apply(x, num_bits, stats['conv1']['ema_min'], stats['conv1']['ema_max'])

  conv2weight = model.features[2].weight.data
  model.features[2].weight.data = FakeQuantOp.apply(model.features[2].weight.data, num_bits)
  x = model.features[3](model.features[2](x))

  with torch.no_grad():
    stats = updateStats(x.clone().view(x.shape[0], -1), stats, 'conv2')

  if act_quant:
    x = FakeQuantOp.apply(x, num_bits, stats['conv2']['ema_min'], stats['conv2']['ema_max'])

  conv3weight = model.features[4].weight.data
  model.features[4].weight.data = FakeQuantOp.apply(model.features[4].weight.data, num_bits)
  x = model.features[5](model.features[4](x))

  with torch.no_grad():
    stats = updateStats(x.clone().view(x.shape[0], -1), stats, 'conv3')

  if act_quant:
    x = FakeQuantOp.apply(x, num_bits, stats['conv3']['ema_min'], stats['conv3']['ema_max'])


  conv4weight = model.features[6].weight.data
  model.features[6].weight.data = FakeQuantOp.apply(model.features[6].weight.data, num_bits)
  x = model.features[7](model.features[6](x))

  with torch.no_grad():
    stats = updateStats(x.clone().view(x.shape[0], -1), stats, 'conv4')

  if act_quant:
    x = FakeQuantOp.apply(x, num_bits, stats['conv4']['ema_min'], stats['conv4']['ema_max'])


  conv5weight = model.features[8].weight.data
  model.features[8].weight.data = FakeQuantOp.apply(model.features[8].weight.data, num_bits)
  x = model.features[9](model.features[8](x))

  with torch.no_grad():
    stats = updateStats(x.clone().view(x.shape[0], -1), stats, 'conv5')

  if act_quant:
    x = FakeQuantOp.apply(x, num_bits, stats['conv5']['ema_min'], stats['conv5']['ema_max'])



  conv6weight = model.features[10].weight.data
  model.features[10].weight.data = FakeQuantOp.apply(model.features[10].weight.data, num_bits)
  x = model.features[11](model.features[10](x))

  with torch.no_grad():
    stats = updateStats(x.clone().view(x.shape[0], -1), stats, 'conv6')

  if act_quant:
    x = FakeQuantOp.apply(x, num_bits, stats['conv6']['ema_min'], stats['conv6']['ema_max'])


  conv7weight = model.features[12].weight.data
  model.features[12].weight.data = FakeQuantOp.apply(model.features[12].weight.data, num_bits)
  x = model.features[13](model.features[12](x))

  with torch.no_grad():
    stats = updateStats(x.clone().view(x.shape[0], -1), stats, 'conv7')

  if act_quant:
    x = FakeQuantOp.apply(x, num_bits, stats['conv7']['ema_min'], stats['conv7']['ema_max'])


  conv8weight = model.features[14].weight.data
  model.features[14].weight.data = FakeQuantOp.apply(model.features[14].weight.data, num_bits)
  x = model.features[15](model.features[14](x))

  with torch.no_grad():
    stats = updateStats(x.clone().view(x.shape[0], -1), stats, 'conv8')

  if act_quant:
    x = FakeQuantOp.apply(x, num_bits, stats['conv8']['ema_min'], stats['conv8']['ema_max'])

  x = x.view(-1, 512) 
  x = model.classifier(x)

  with torch.no_grad():
    stats = updateStats(x.clone().view(x.shape[0], -1), stats, 'fc')


  return x, conv1weight, conv2weight, conv3weight, conv4weight, conv5weight, conv6weight, conv7weight, conv8weight, stats

# Training
def train(epoch, trainloader, optimizer, criterion, model, device, stats, act_quant=False, num_bits=8):
    print('\nEpoch: %d' % epoch)
    model.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        #outputs = net(inputs)
        outputs, conv1weight, conv2weight, conv3weight, conv4weight, conv5weight, conv6weight, conv7weight, conv8weight, stats = quantAwareTrainingForward(model, inputs, stats, num_bits=num_bits, act_quant=act_quant)
        model.features[0].weight.data   = conv1weight
        model.features[2].weight.data   = conv2weight
        model.features[4].weight.data   = conv3weight
        model.features[6].weight.data   = conv4weight
        model.features[8].weight.data   = conv5weight
        model.features[10].weight.data  = conv6weight
        model.features[12].weight.data  = conv7weight
        model.features[14].weight.data  = conv8weight
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                     % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))


def test(epoch, testloader, criterion, model, device, stats, act_quant, num_bits=8):
    global best_acc
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            #outputs = net(inputs)
            outputs, conv1weight, conv2weight, conv3weight, conv4weight, conv5weight, conv6weight, conv7weight, conv8weight, stats = quantAwareTrainingForward(model, inputs, stats, num_bits=num_bits, act_quant=act_quant)
            model.features[0].weight.data   = conv1weight
            model.features[2].weight.data   = conv2weight
            model.features[4].weight.data   = conv3weight
            model.features[6].weight.data   = conv4weight
            model.features[8].weight.data   = conv5weight
            model.features[10].weight.data  = conv6weight
            model.features[12].weight.data  = conv7weight
            model.features[14].weight.data  = conv8weight
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            print(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                         % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))

#     # Save checkpoint.
#     acc = 100.*correct/total
#     if acc > best_acc:
#         print('Saving model with validation acc, loss: ',acc,' ,',test_loss)
#         state = {
#             'net': net.state_dict(),
#             'acc': acc,
#             'epoch': epoch,
#         }
#         torch.save(state, './cifar_qat.pt')
#         best_acc = acc



In [40]:
torch.manual_seed(0)
np.random.seed(0)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(" Data loading started...")
bs = 16
num_bits=8

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    #transforms.AutoAugment(transforms.AutoAugmentPolicy.CIFAR10),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.CIFAR10(root='../../../formal_pruning/dataset', train=True, download=False, transform=transform_train)
train_loader = torch.utils.data.DataLoader(trainset, batch_size=bs, shuffle=True)
testset = torchvision.datasets.CIFAR10(root='../../../formal_pruning/dataset', train=False, download=False, transform=transform_test)
test_loader = torch.utils.data.DataLoader(testset, batch_size=bs, shuffle=False)
num_classes = 10
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

#pretrained_modelqat = "./cifar_vgg_sym_v3.pt"
pretrained_modelqat = "./cifar_qat.pt"
netqat = VGG('VGG11')
sdqat = torch.load(pretrained_modelqat, map_location=torch.device('cpu'))
netqat.load_state_dict(sdqat['net'])
stats = gatherStats(netqat, test_loader)
print(stats) 
criterion = nn.CrossEntropyLoss()
epoch = 1
act_quant = True 
test(epoch, test_loader, criterion, netqat, device, stats, act_quant, num_bits=8)

 Data loading started...
{'conv1': {'max': tensor(38.0929), 'min': tensor(-34.3413), 'ema_min': -2.2328857964704634, 'ema_max': 2.5693619505756042}, 'conv2': {'max': tensor(104.3430), 'min': tensor(0.), 'ema_min': 0.0, 'ema_max': 7.106570094804752}, 'conv3': {'max': tensor(107.0603), 'min': tensor(0.), 'ema_min': 0.0, 'ema_max': 6.664907399560724}, 'conv4': {'max': tensor(74.4338), 'min': tensor(0.), 'ema_min': 0.0, 'ema_max': 5.193922243396578}, 'conv5': {'max': tensor(52.3959), 'min': tensor(0.), 'ema_min': 0.0, 'ema_max': 3.8755614585320126}, 'conv6': {'max': tensor(24.7309), 'min': tensor(0.), 'ema_min': 0.0, 'ema_max': 1.6148410296951774}, 'conv7': {'max': tensor(15.9820), 'min': tensor(0.), 'ema_min': 0.0, 'ema_max': 0.9670984278935988}, 'conv8': {'max': tensor(31.2128), 'min': tensor(0.), 'ema_min': 0.0, 'ema_max': 1.8251638992648405}, 'fc': {'max': tensor(62.8690), 'min': tensor(0.), 'ema_min': 0.0, 'ema_max': 3.908302091814401}}
0 625 Loss: 2.309 | Acc: 12.500% (2/16)
1 625 Lo

160 625 Loss: 0.816 | Acc: 88.354% (2276/2576)
161 625 Loss: 0.814 | Acc: 88.349% (2290/2592)
162 625 Loss: 0.820 | Acc: 88.344% (2304/2608)
163 625 Loss: 0.817 | Acc: 88.377% (2319/2624)
164 625 Loss: 0.820 | Acc: 88.371% (2333/2640)
165 625 Loss: 0.816 | Acc: 88.366% (2347/2656)
166 625 Loss: 0.811 | Acc: 88.436% (2363/2672)
167 625 Loss: 0.807 | Acc: 88.504% (2379/2688)
168 625 Loss: 0.804 | Acc: 88.499% (2393/2704)
169 625 Loss: 0.799 | Acc: 88.566% (2409/2720)
170 625 Loss: 0.796 | Acc: 88.596% (2424/2736)
171 625 Loss: 0.794 | Acc: 88.590% (2438/2752)
172 625 Loss: 0.791 | Acc: 88.620% (2453/2768)
173 625 Loss: 0.789 | Acc: 88.649% (2468/2784)
174 625 Loss: 0.793 | Acc: 88.643% (2482/2800)
175 625 Loss: 0.792 | Acc: 88.636% (2496/2816)
176 625 Loss: 0.794 | Acc: 88.665% (2511/2832)
177 625 Loss: 0.795 | Acc: 88.624% (2524/2848)
178 625 Loss: 0.791 | Acc: 88.687% (2540/2864)
179 625 Loss: 0.788 | Acc: 88.715% (2555/2880)
180 625 Loss: 0.787 | Acc: 88.709% (2569/2896)
181 625 Loss:

337 625 Loss: 0.711 | Acc: 89.312% (4830/5408)
338 625 Loss: 0.712 | Acc: 89.307% (4844/5424)
339 625 Loss: 0.712 | Acc: 89.320% (4859/5440)
340 625 Loss: 0.714 | Acc: 89.296% (4872/5456)
341 625 Loss: 0.714 | Acc: 89.291% (4886/5472)
342 625 Loss: 0.712 | Acc: 89.322% (4902/5488)
343 625 Loss: 0.710 | Acc: 89.335% (4917/5504)
344 625 Loss: 0.712 | Acc: 89.330% (4931/5520)
345 625 Loss: 0.712 | Acc: 89.324% (4945/5536)
346 625 Loss: 0.715 | Acc: 89.319% (4959/5552)
347 625 Loss: 0.715 | Acc: 89.314% (4973/5568)
348 625 Loss: 0.713 | Acc: 89.345% (4989/5584)
349 625 Loss: 0.712 | Acc: 89.357% (5004/5600)
350 625 Loss: 0.715 | Acc: 89.334% (5017/5616)
351 625 Loss: 0.713 | Acc: 89.364% (5033/5632)
352 625 Loss: 0.715 | Acc: 89.341% (5046/5648)
353 625 Loss: 0.713 | Acc: 89.354% (5061/5664)
354 625 Loss: 0.712 | Acc: 89.349% (5075/5680)
355 625 Loss: 0.711 | Acc: 89.343% (5089/5696)
356 625 Loss: 0.709 | Acc: 89.373% (5105/5712)
357 625 Loss: 0.709 | Acc: 89.385% (5120/5728)
358 625 Loss:

513 625 Loss: 0.673 | Acc: 89.664% (7374/8224)
514 625 Loss: 0.673 | Acc: 89.672% (7389/8240)
515 625 Loss: 0.672 | Acc: 89.692% (7405/8256)
516 625 Loss: 0.671 | Acc: 89.688% (7419/8272)
517 625 Loss: 0.673 | Acc: 89.660% (7431/8288)
518 625 Loss: 0.674 | Acc: 89.632% (7443/8304)
519 625 Loss: 0.673 | Acc: 89.627% (7457/8320)
520 625 Loss: 0.674 | Acc: 89.599% (7469/8336)
521 625 Loss: 0.673 | Acc: 89.607% (7484/8352)
522 625 Loss: 0.674 | Acc: 89.579% (7496/8368)
523 625 Loss: 0.673 | Acc: 89.575% (7510/8384)
524 625 Loss: 0.672 | Acc: 89.583% (7525/8400)
525 625 Loss: 0.671 | Acc: 89.579% (7539/8416)
526 625 Loss: 0.671 | Acc: 89.575% (7553/8432)
527 625 Loss: 0.670 | Acc: 89.583% (7568/8448)
528 625 Loss: 0.671 | Acc: 89.579% (7582/8464)
529 625 Loss: 0.671 | Acc: 89.564% (7595/8480)
530 625 Loss: 0.671 | Acc: 89.560% (7609/8496)
531 625 Loss: 0.670 | Acc: 89.579% (7625/8512)
532 625 Loss: 0.670 | Acc: 89.587% (7640/8528)
533 625 Loss: 0.669 | Acc: 89.595% (7655/8544)
534 625 Loss:

In [41]:
torch.manual_seed(0)
np.random.seed(0)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(" Data loading started...")
bs = 16
num_bits=8

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    #transforms.AutoAugment(transforms.AutoAugmentPolicy.CIFAR10),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.CIFAR10(root='../../../formal_pruning/dataset', train=True, download=False, transform=transform_train)
train_loader = torch.utils.data.DataLoader(trainset, batch_size=bs, shuffle=True)
testset = torchvision.datasets.CIFAR10(root='../../../formal_pruning/dataset', train=False, download=False, transform=transform_test)
test_loader = torch.utils.data.DataLoader(testset, batch_size=bs, shuffle=False)
num_classes = 10
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

#pretrained_modelqat = "./cifar_vgg_sym_v3.pt"
pretrained_modelqat = "./cifar_qat.pt"
netqat = VGG('VGG11')
sdqat = torch.load(pretrained_modelqat, map_location=torch.device('cpu'))
netqat.load_state_dict(sdqat['net'])
stats = gatherStats(netqat, test_loader)
print(stats) 
criterion = nn.CrossEntropyLoss()
epoch = 1
act_quant = True 
test(epoch, test_loader, criterion, netqat, device, stats, act_quant, num_bits=4)

 Data loading started...
{'conv1': {'max': tensor(38.0929), 'min': tensor(-34.3413), 'ema_min': -2.2328857964704634, 'ema_max': 2.5693619505756042}, 'conv2': {'max': tensor(104.3430), 'min': tensor(0.), 'ema_min': 0.0, 'ema_max': 7.106570094804752}, 'conv3': {'max': tensor(107.0603), 'min': tensor(0.), 'ema_min': 0.0, 'ema_max': 6.664907399560724}, 'conv4': {'max': tensor(74.4338), 'min': tensor(0.), 'ema_min': 0.0, 'ema_max': 5.193922243396578}, 'conv5': {'max': tensor(52.3959), 'min': tensor(0.), 'ema_min': 0.0, 'ema_max': 3.8755614585320126}, 'conv6': {'max': tensor(24.7309), 'min': tensor(0.), 'ema_min': 0.0, 'ema_max': 1.6148410296951774}, 'conv7': {'max': tensor(15.9820), 'min': tensor(0.), 'ema_min': 0.0, 'ema_max': 0.9670984278935988}, 'conv8': {'max': tensor(31.2128), 'min': tensor(0.), 'ema_min': 0.0, 'ema_max': 1.8251638992648405}, 'fc': {'max': tensor(62.8690), 'min': tensor(0.), 'ema_min': 0.0, 'ema_max': 3.908302091814401}}
0 625 Loss: 2.309 | Acc: 12.500% (2/16)
1 625 Lo

161 625 Loss: 1.313 | Acc: 78.858% (2044/2592)
162 625 Loss: 1.320 | Acc: 78.758% (2054/2608)
163 625 Loss: 1.319 | Acc: 78.735% (2066/2624)
164 625 Loss: 1.319 | Acc: 78.750% (2079/2640)
165 625 Loss: 1.325 | Acc: 78.727% (2091/2656)
166 625 Loss: 1.317 | Acc: 78.855% (2107/2672)
167 625 Loss: 1.311 | Acc: 78.869% (2120/2688)
168 625 Loss: 1.305 | Acc: 78.957% (2135/2704)
169 625 Loss: 1.300 | Acc: 79.044% (2150/2720)
170 625 Loss: 1.305 | Acc: 79.094% (2164/2736)
171 625 Loss: 1.302 | Acc: 79.142% (2178/2752)
172 625 Loss: 1.302 | Acc: 79.118% (2190/2768)
173 625 Loss: 1.307 | Acc: 79.059% (2201/2784)
174 625 Loss: 1.313 | Acc: 79.036% (2213/2800)
175 625 Loss: 1.314 | Acc: 79.048% (2226/2816)
176 625 Loss: 1.322 | Acc: 79.061% (2239/2832)
177 625 Loss: 1.330 | Acc: 78.968% (2249/2848)
178 625 Loss: 1.331 | Acc: 78.911% (2260/2864)
179 625 Loss: 1.325 | Acc: 78.958% (2274/2880)
180 625 Loss: 1.321 | Acc: 79.006% (2288/2896)
181 625 Loss: 1.317 | Acc: 79.087% (2303/2912)
182 625 Loss:

337 625 Loss: 1.272 | Acc: 79.882% (4320/5408)
338 625 Loss: 1.270 | Acc: 79.904% (4334/5424)
339 625 Loss: 1.267 | Acc: 79.945% (4349/5440)
340 625 Loss: 1.266 | Acc: 79.930% (4361/5456)
341 625 Loss: 1.266 | Acc: 79.916% (4373/5472)
342 625 Loss: 1.264 | Acc: 79.920% (4386/5488)
343 625 Loss: 1.265 | Acc: 79.906% (4398/5504)
344 625 Loss: 1.268 | Acc: 79.873% (4409/5520)
345 625 Loss: 1.274 | Acc: 79.805% (4418/5536)
346 625 Loss: 1.275 | Acc: 79.827% (4432/5552)
347 625 Loss: 1.275 | Acc: 79.813% (4444/5568)
348 625 Loss: 1.273 | Acc: 79.782% (4455/5584)
349 625 Loss: 1.273 | Acc: 79.786% (4468/5600)
350 625 Loss: 1.276 | Acc: 79.754% (4479/5616)
351 625 Loss: 1.276 | Acc: 79.741% (4491/5632)
352 625 Loss: 1.281 | Acc: 79.674% (4500/5648)
353 625 Loss: 1.282 | Acc: 79.643% (4511/5664)
354 625 Loss: 1.282 | Acc: 79.613% (4522/5680)
355 625 Loss: 1.283 | Acc: 79.635% (4536/5696)
356 625 Loss: 1.280 | Acc: 79.657% (4550/5712)
357 625 Loss: 1.280 | Acc: 79.644% (4562/5728)
358 625 Loss:

513 625 Loss: 1.274 | Acc: 79.669% (6552/8224)
514 625 Loss: 1.273 | Acc: 79.684% (6566/8240)
515 625 Loss: 1.272 | Acc: 79.675% (6578/8256)
516 625 Loss: 1.271 | Acc: 79.691% (6592/8272)
517 625 Loss: 1.274 | Acc: 79.657% (6602/8288)
518 625 Loss: 1.276 | Acc: 79.612% (6611/8304)
519 625 Loss: 1.275 | Acc: 79.615% (6624/8320)
520 625 Loss: 1.275 | Acc: 79.619% (6637/8336)
521 625 Loss: 1.277 | Acc: 79.586% (6647/8352)
522 625 Loss: 1.279 | Acc: 79.565% (6658/8368)
523 625 Loss: 1.279 | Acc: 79.556% (6670/8384)
524 625 Loss: 1.281 | Acc: 79.512% (6679/8400)
525 625 Loss: 1.279 | Acc: 79.515% (6692/8416)
526 625 Loss: 1.279 | Acc: 79.530% (6706/8432)
527 625 Loss: 1.281 | Acc: 79.522% (6718/8448)
528 625 Loss: 1.281 | Acc: 79.537% (6732/8464)
529 625 Loss: 1.281 | Acc: 79.540% (6745/8480)
530 625 Loss: 1.284 | Acc: 79.508% (6755/8496)
531 625 Loss: 1.282 | Acc: 79.535% (6770/8512)
532 625 Loss: 1.283 | Acc: 79.538% (6783/8528)
533 625 Loss: 1.282 | Acc: 79.553% (6797/8544)
534 625 Loss: