# Initialization

## Import Libraries

In [1]:
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms

In [2]:
from collections import namedtuple

In [3]:
# make sure GPU is being used 
torch.cuda.current_device() 
torch.cuda.device(0)
torch.cuda.get_device_name(0)

'Tesla K80'

# Model

In [6]:
model = torchvision.models.resnet18(pretrained=True)

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


  0%|          | 0.00/44.7M [00:00<?, ?B/s]

In [7]:
model

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

## CIFAR-10

In [None]:
class Net(nn.Module):
    def __init__(self):
      
        super(Net, self).__init__()
        num_channels = 3
          
        self.conv1 = nn.Conv2d(num_channels, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(1250, 500)
        self.flatten_shape = 1250

        self.fc2 = nn.Linear(500, 10)
        
      
    def forward(self, x, vis=False, axs=None):
        X = 0
        Y = 0

        if vis:
            axs[X,Y].set_xlabel('Entry into network, input distribution visualised below: ')
            visualise(x, axs[X,Y])
            
            axs[X,Y+1].set_xlabel("Visualising weights of conv 1 layer: ")
            visualise(self.conv1.weight.data, axs[X,Y+1])


        x = F.relu(self.conv1(x))

        if vis:
            axs[X,Y+2].set_xlabel('Output after conv1 visualised below: ')
            visualise(x,axs[X,Y+2])
            
            axs[X,Y+3].set_xlabel("Visualising weights of conv 2 layer: ")
            visualise(self.conv2.weight.data, axs[X,Y+3])

        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))

        if vis:
            axs[X,Y+4].set_xlabel('Output after conv2 visualised below: ')
            visualise(x,axs[X,Y+4])
            
            axs[X+1,Y].set_xlabel("Visualising weights of fc 1 layer: ")
            visualise(self.fc1.weight.data, axs[X+1,Y])

        x = F.max_pool2d(x, 2, 2)  
        x = x.view(-1, self.flatten_shape)
        x = F.relu(self.fc1(x))

        if vis:
            axs[X+1,Y+1].set_xlabel('Output after fc1 visualised below: ')
            visualise(x,axs[X+1,Y+1])
            
            axs[X+1,Y+2].set_xlabel("Visualising weights of fc 2 layer: ")
            visualise(self.fc2.weight.data, axs[X+1,Y+2])

        x = self.fc2(x)

        if vis:
            axs[X+1,Y+3].set_xlabel('Output after fc2 visualised below: ')
            visualise(x,axs[X+1,Y+3])

        return F.log_softmax(x, dim=1)
    

# Post Training Quantization

## Train Test Loop Functions

In [None]:
def test(args, model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

In [None]:
def train(args, model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        
   
        if batch_idx % args["log_interval"] == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

In [None]:
def main():
 
    batch_size = 64
    test_batch_size = 64
    epochs = 10
    lr = 0.01
    momentum = 0.5
    seed = 1
    log_interval = 500
    save_model = False
    no_cuda = False
    
    use_cuda = not no_cuda and torch.cuda.is_available()

    torch.manual_seed(seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

    transform = transforms.Compose(
        [transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    
    trainset = datasets.CIFAR10(root='./dataCifar', train=True,
                                            download=True, transform=transform)
    train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                            shuffle=True, num_workers=2)
    
    testset = datasets.CIFAR10(root='./dataCifar', train=False,
                                        download=True, transform=transform)
    
    test_loader = torch.utils.data.DataLoader(testset, batch_size=test_batch_size,
                                            shuffle=False, num_workers=2)
          
  
    model = Net().to(device)
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
    args = {}
    args["log_interval"] = log_interval
    for epoch in range(1, epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch)
        test(args, model, device, test_loader)

    if (save_model):
        torch.save(model.state_dict(),"mnist_cnn.pt")
    
    return model

In [None]:
model = main()

Files already downloaded and verified
Files already downloaded and verified

Test set: Average loss: 1.6384, Accuracy: 4033/10000 (40%)


Test set: Average loss: 1.4415, Accuracy: 4781/10000 (48%)


Test set: Average loss: 1.3010, Accuracy: 5353/10000 (54%)


Test set: Average loss: 1.2959, Accuracy: 5309/10000 (53%)


Test set: Average loss: 1.1161, Accuracy: 6033/10000 (60%)


Test set: Average loss: 1.0464, Accuracy: 6352/10000 (64%)


Test set: Average loss: 1.0252, Accuracy: 6426/10000 (64%)


Test set: Average loss: 1.0219, Accuracy: 6465/10000 (65%)


Test set: Average loss: 0.9444, Accuracy: 6775/10000 (68%)


Test set: Average loss: 0.9150, Accuracy: 6840/10000 (68%)



## Quantisation of Network

### Quantization Functions

In [None]:
QTensor = namedtuple('QTensor', ['tensor', 'scale', 'zero_point'])

In [None]:
def calcScaleZeroPoint(min_val, max_val,num_bits):
    # Calc Scale and zero point of next
    qmin = 0.
    qmax = 2.**num_bits - 1.
    
    scale_next = (max_val - min_val) / (qmax - qmin)
    initial_zero_point = qmin - min_val / scale_next
  
    if initial_zero_point < qmin:
        zero_point_next = qmin
    elif initial_zero_point > qmax:
        zero_point_next = qmax
    else:
        zero_point_next = initial_zero_point
        #print(zero_point_next)
        
    zero_point_next = int(zero_point_next)
    
    return scale_next, zero_point_next

In [None]:
def quantize_tensor(x, num_bits, min_val=None, max_val=None):
    
    #x = torch.nan_to_num(x)
    if not min_val and not max_val:
        min_val, max_val = x.min(), x.max()
    else:
        x.clamp_(min_val, max_val)

    qmin = 0.
    qmax = 2.**num_bits - 1.

    scale, zero_point = calcScaleZeroPoint(min_val, max_val, num_bits)
    x = zero_point + x / scale
    x.clamp_(qmin, qmax).round_()
    
    return QTensor(tensor=x, scale=scale, zero_point=zero_point)

In [None]:
def dequantize_tensor(q_x):
    return q_x.scale * (q_x.tensor.float() - q_x.zero_point)

In [None]:
def calcScaleZeroPointSym(min_val, max_val,num_bits):
  
    # Calc Scale
    max_val = max(abs(min_val), abs(max_val))
    qmax = 2.**(num_bits-1) - 1.
    
    scale = max_val / qmax
    
    return scale, 0

In [None]:
def quantize_tensor_sym(x, num_bits, min_val=None, max_val=None):
    #x = torch.nan_to_num(x)
    if not min_val and not max_val:
        min_val, max_val = x.min(), x.max()
    else:
        x.clamp_(min_val, max_val)
        max_val = max(abs(min_val), abs(max_val))
    
    qmax = 2.**(num_bits-1) - 1.

    scale = max_val / qmax   

    x = x/scale

    x.clamp_(-qmax, qmax).round_()
    return QTensor(tensor=x, scale=scale, zero_point=0)

In [None]:
def dequantize_tensor_sym(q_x):
    return q_x.scale * (q_x.tensor.float())

### Rework Forward pass of Linear and Conv Layers to support Quantisation

In [None]:
def quantizeLayer(x, layer, stat, scale_x, zp_x, vis=False, axs=None, X=None, Y=None, sym=False, num_bits=8):
    # for both conv and linear layers
    
    # cache old values
    W = layer.weight.data
    B = layer.bias.data
    
    # WEIGHTS SIMULATED QUANTISED
    
    # quantise weights, activations are already quantised
    if sym:
        w = quantize_tensor_sym(layer.weight.data,num_bits=num_bits)
        b = quantize_tensor_sym(layer.bias.data,num_bits=num_bits)
    else:
        w = quantize_tensor(layer.weight.data, num_bits=num_bits)
        b = quantize_tensor(layer.bias.data, num_bits=num_bits)
        
    layer.weight.data = w.tensor.float()
    layer.bias.data = b.tensor.float()
    
    ## END WEIGHTS QUANTISED SIMULATION
    
    if vis:
        axs[X,Y].set_xlabel("Visualising weights of layer: ")
        visualise(layer.weight.data, axs[X,Y])
        
    # QUANTISED OP, USES SCALE AND ZERO POINT TO DO LAYER FORWARD PASS. (How does backprop change here ?)
    # This is Quantisation Arithmetic
    
    scale_w = w.scale
    zp_w = w.zero_point
    scale_b = b.scale
    zp_b = b.zero_point
    
    if sym:
        scale_next, zero_point_next = calcScaleZeroPointSym(min_val=stat['min'], max_val=stat['max'], num_bits=num_bits)
    else:
        scale_next, zero_point_next = calcScaleZeroPoint(min_val=stat['min'], max_val=stat['max'], num_bits=num_bits)
        
    # Preparing input by saturating range to num_bits range.
    if sym:
        x_ = x.float()
        layer.weight.data = ((scale_x * scale_w) / scale_next)*(layer.weight.data)
        layer.bias.data = (scale_b/scale_next)*(layer.bias.data)
    else:
        x_ = x.float() - zp_x
        layer.weight.data = ((scale_x * scale_w) / scale_next)*(layer.weight.data - zp_w)
        layer.bias.data = (scale_b/scale_next)*(layer.bias.data + zp_b)

    # All (Fake) int computation
    if sym:  
        x = (layer(x_)) 
        qmin = -2.**(num_bits -1)
        qmax = 2.**(num_bits -1) - 1
    else:
        x = (layer(x_)) + zero_point_next
        qmin = 0
        qmax = 2.**(num_bits) - 1
        
    # cast to int
    x.clamp_(qmin, qmax).round_()
    
    # Perform relu too
    x = F.leaky_relu(x) #?
    #x = F.relu(x)
    
    # Reset weights for next forward pass
    layer.weight.data = W
    layer.bias.data = B
    
    return x, scale_next, zero_point_next

### Get Stats for Quantising Activations of Network.

This is done by running the network with around 1000 examples and getting the average min and max activation values before and after each layer.

In [None]:
# Get Min and max of x tensor, and stores it
def updateStats(x, stats, key):
    max_val, _ = torch.max(x, dim=1)
    min_val, _ = torch.min(x, dim=1)
    
    # add ema calculation
    
    if key not in stats:
        stats[key] = {"max": torch.max(max_val).item(), "min": torch.min(min_val).item(), "total": 1}
        
    else:
        stats[key]['max'] = max(stats[key]['max'],torch.max(max_val).item())
        stats[key]['min'] = min(stats[key]['min'],torch.min(max_val).item())
        stats[key]['total'] += 1
        
    weighting = 2.0 / (stats[key]['total']) + 1
    
    if 'ema_min' in stats[key]:
        stats[key]['ema_min'] = weighting*(min_val.mean().item()) + (1- weighting) * stats[key]['ema_min']
    else:
        stats[key]['ema_min'] = weighting*(min_val.mean().item())
        
    if 'ema_max' in stats[key]:
        stats[key]['ema_max'] = weighting*(max_val.mean().item()) + (1- weighting) * stats[key]['ema_max']
    else:
        stats[key]['ema_max'] = weighting*(max_val.mean().item())
    
    
    return stats

In [None]:
# Reworked Forward Pass to access activation Stats through updateStats function
def gatherActivationStats(model, x, stats):
    
    stats = updateStats(x.clone().view(x.shape[0], -1), stats, 'conv1')
    
    x = F.relu(model.conv1(x))
    
    x = F.max_pool2d(x, 2, 2)
    
    stats = updateStats(x.clone().view(x.shape[0], -1), stats, 'conv2')
    
    x = F.relu(model.conv2(x))
    
    x = F.max_pool2d(x, 2, 2)
    
    x = x.view(-1, 1250)
    
    stats = updateStats(x, stats, 'fc1')
    
    x = F.relu(model.fc1(x))
    
    stats = updateStats(x, stats, 'fc2')
    
    x = model.fc2(x)
    
    stats = updateStats(x, stats, 'out')

    return stats

In [None]:
# Entry function to get stats of all functions.
def gatherStats(model, test_loader):
    device = 'cuda'    
    model.eval()
    test_loss = 0
    correct = 0
    stats = {}
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            stats = gatherActivationStats(model, data, stats)
    
    final_stats = {}
    for key, value in stats.items():
        final_stats[key] = { "max" : value["max"], "min" : value["min"], "ema_min": value["ema_min"], "ema_max": value["ema_max"] }
    return final_stats

### Forward Pass for Quantised Inference

In [None]:
def quantForward(model, x, stats, vis=False, axs=None, sym=False, num_bits=8):
    X = 0
    Y = 0
    # Quantise before inputting into incoming layers
    if sym:
        x = quantize_tensor_sym(x, min_val=stats['conv1']['min'], max_val=stats['conv1']['max'], num_bits=num_bits)
    else:
        x = quantize_tensor(x, min_val=stats['conv1']['min'], max_val=stats['conv1']['max'], num_bits=num_bits)

    if vis:
        axs[X,Y].set_xlabel('Entry into network, input distribution visualised below: ')
        visualise(x.tensor, axs[X,Y])
        
    x, scale_next, zero_point_next = quantizeLayer(x.tensor, model.conv1, stats['conv2'], x.scale, x.zero_point, vis, axs, X=X, Y=Y+1, sym=sym, num_bits=num_bits)

    x = F.max_pool2d(x, 2, 2)
  
    if vis:
        axs[X,Y+2].set_xlabel('Output after conv1 visualised below: ')
        visualise(x,axs[X,Y+2])

    x, scale_next, zero_point_next = quantizeLayer(x, model.conv2, stats['fc1'], scale_next, zero_point_next, vis, axs, X=X, Y=Y+3, sym=sym, num_bits=num_bits)

    x = F.max_pool2d(x, 2, 2)

    if vis:
        axs[X,Y+4].set_xlabel('Output after conv2 visualised below: ')
        visualise(x,axs[X,Y+4])

    x = x.view(-1, 1250)

    x, scale_next, zero_point_next = quantizeLayer(x, model.fc1, stats['fc2'], scale_next, zero_point_next, vis, axs, X=X+1, Y=0, sym=sym, num_bits=num_bits)

    if vis:
        axs[X+1,Y+1].set_xlabel('Output after fc1 visualised below: ')
        visualise(x,axs[X+1,Y+1])
  
    x, scale_next, zero_point_next = quantizeLayer(x, model.fc2, stats['out'], scale_next, zero_point_next, vis, axs, X=X+1, Y=Y+2, sym=sym, num_bits=num_bits)
    
    if vis:
        axs[X+1,Y+3].set_xlabel('Output after fc2 visualised below: ')
        visualise(x,axs[X+1,Y+3])
        
    # Back to dequant for final layer
    if sym:
        x = dequantize_tensor_sym(QTensor(tensor=x, scale=scale_next, zero_point=zero_point_next))
    else:
        x = dequantize_tensor(QTensor(tensor=x, scale=scale_next, zero_point=zero_point_next))

    if vis:
        axs[X+1,Y+4].set_xlabel('Output after fc2 but dequantised visualised below: ')
        visualise(x,axs[X+1,Y+4])

    return F.log_softmax(x, dim=1)

In [None]:
def testQuant(model, test_loader, quant=False, stats=None, sym=False, num_bits=8):
    device = 'cuda'
    
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            if quant:
                output = quantForward(model, data, stats, sym=sym, num_bits=num_bits)
            else:
                output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

## Get Accuracy for Quantized Model

In [None]:
import copy
q_model = copy.deepcopy(model)

In [None]:
kwargs = {'num_workers': 1, 'pin_memory': True}

transform = transforms.Compose(
        [transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

testset = datasets.CIFAR10(root='./dataCifar', train=False,
                                        download=True, transform=transform)
    
test_loader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=True, **kwargs)

Files already downloaded and verified


In [None]:
# Non Quantized
testQuant(q_model, test_loader, quant=False)


Test set: Average loss: 0.9150, Accuracy: 6840/10000 (68%)



### Gather Stats of Activations

In [None]:
stats = gatherStats(q_model, test_loader)
print(stats)

{'conv1': {'max': 1.0, 'min': -1.0, 'ema_min': -0.8961577837067178, 'ema_max': 0.906515912253188}, 'conv2': {'max': 10.053322792053223, 'min': 0.0, 'ema_min': 0.0, 'ema_max': 4.797752144011583}, 'fc1': {'max': 13.948284149169922, 'min': 0.0, 'ema_min': 0.0, 'ema_max': 5.490657327288742}, 'fc2': {'max': 8.615355491638184, 'min': 0.0, 'ema_min': 0.0, 'ema_max': 3.403652658938573}, 'out': {'max': 16.977933883666992, 'min': -8.503494262695312, 'ema_min': -4.213035598256015, 'ema_max': 5.073444108768709}}


In [None]:
testQuant(q_model, test_loader, quant=True, stats=stats, sym=False, num_bits=8)


Test set: Average loss: 3.6772, Accuracy: 4785/10000 (48%)



In [None]:
testQuant(q_model, test_loader, quant=True, stats=stats, sym=False, num_bits=4)


Test set: Average loss: 4.4507, Accuracy: 4093/10000 (41%)



In [None]:
testQuant(q_model, test_loader, quant=True, stats=stats, sym=False, num_bits=2)


Test set: Average loss: 10.1600, Accuracy: 1746/10000 (17%)

