# Quantising a MobileNetV1 with the PACT algorithm

This notebook shows how to create an integerised [MobileNetV1](https://arxiv.org/pdf/1704.04861.pdf) using the QuantLib package.


In [1]:
from __future__ import annotations

from typing import NamedTuple, List, Tuple, Union, Optional


## Part 1: creating and evaluating a floating-point network

### Step 1: check the computing infrastructure

Depending on the hardware at our disposal, we will make different choices for the training and testing processes.


In [7]:
import multiprocessing
import torch

n_cpus = multiprocessing.cpu_count()
n_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
print(f"Available system CPU(s): {n_cpus}.")
print(f"Available system GPU(s): {n_gpus}.")

device = torch.device(torch.cuda.current_device()) if (n_gpus > 0) else torch.device('cpu')


Available system CPU(s): 12.
Available system GPU(s): 0.


### Step 2: load data into PyTorch

Every problem in supervised learning requires a data set.
We can partition the data points of this data set into three categories:
* **training** points: they are available at training time and we can use them to update the model's parameters;
* **validation** points: they are available at training time, but we cannot use them to update the parameters; instead, we use them to measure the performance of the model throughout the learning process;
* **test** points: they are not available at training time but only once the model is deployed in production.


In [8]:
from enum import Enum, auto, unique


# each supervised learning problem defines a training/validation/test partition of its data points
@unique
class DataSetPartition(Enum):
    TRAINING   = auto()
    VALIDATION = auto()
    TEST       = auto()

    @classmethod
    def canonicalise(cls, partition_spec: DataSetPartitionSpec) -> DataSetPartition:

        # validate specification type
        if not isinstance(partition_spec, (DataSetPartition, str,)):
            raise TypeError

        if isinstance(partition_spec, DataSetPartition):
            partition = partition_spec
        
        else:  # `isinstance(partition_spec, str)`

            partition_spec = partition_spec.upper()
            if partition_spec in {'TRAINING', 'TRAIN'}:
                partition = cls['TRAINING']
            elif partition_spec in {'VALIDATION', 'VALID'}:
                partition = cls['VALIDATION']
            elif partition_spec in {'TEST', 'TESTING'}:
                partition = cls['TEST']
            else:
                raise ValueError
        
        return partition

    
# define the ways in which a user can specify data set partitions
DataSetPartitionSpec = Union[DataSetPartition, str]


PyTorch represents data points using the `torch.Tensor` data structure.
PyTorch uses a specific pipeline to transform files stored on disk into mini-batches of `Tensor`s.
This pipeline consists of four stages.
* Define a function to transform a Python object into a corresponding collection of `Tensor`s. For instance, we can map a labelled image object to a pair consisting of an input tensor and a label tensor. These functions are called *transforms* and can be composed to describe complex pre-processing transformations (see `torchvision.transforms.Compose`).
* Define a `torch.utils.data.Dataset` to index the files stored on disk, specify how to load each file into a Python object, and specify how to map such an object to a corresponding collection of `Tensor`s.
* Define a `torch.utils.data.Sampler` that yields the indices of the data points to load. The indices can be returned in their natural order (e.g., `SequentialSampler`) or shuffled (e.g., `RandomSampler`); shuffled indices can be drawn without replacement (each data point appears once per epoch) or with replacement.
* Define a `torch.utils.data.DataLoader` specifying the mini-batch size and the number of worker processes that should be used to load data point files from disk. `DataLoader`s work as follows:
  * the `DataLoader` queries the `Sampler` for indices and groups them into mini-batches of indices;
  * the `DataLoader` distributes these mini-batches of indices amongst the worker processes;
  * each worker applies the `transforms` specified by the `Dataset` to the files corresponding to the received indices; since `transforms` are applied independently to each data point, the work follows a *map pattern* and can be parallelised amongst multiple workers;
  * each worker collates its (loaded and pre-processed) `Tensor` data points into a mini-batch and returns it to the `DataLoader`;
  * the `DataLoader` returns these mini-batches in the order defined by the `Sampler`.
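The four stages can be sketched end-to-end on a toy problem. The `SquaresDataset` class below is hypothetical (it builds its data points in memory instead of loading files from disk), and `num_workers=0` keeps all the work in the main process; only standard `torch.utils.data` classes are used.

```python
import torch
from torch.utils.data import DataLoader, Dataset, SequentialSampler


class SquaresDataset(Dataset):
    """Hypothetical toy data set: data point i is the pair (i, i**2), built in memory."""

    def __init__(self, n_points: int, transform=None):
        self.n_points = n_points
        self.transform = transform

    def __len__(self) -> int:
        return self.n_points

    def __getitem__(self, index: int):
        x = torch.tensor([float(index)])   # "load" the data point
        if self.transform is not None:
            x = self.transform(x)          # stage 1: per-point transform
        y = torch.tensor(index ** 2)       # label tensor
        return x, y


dataset = SquaresDataset(n_points=10, transform=lambda t: t / 10.0)        # stage 2
sampler = SequentialSampler(dataset)                                        # stage 3
loader = DataLoader(dataset, sampler=sampler, batch_size=4, num_workers=0)  # stage 4

batches = list(loader)
for x, y in batches:
    print(tuple(x.shape), y.tolist())
```

The last mini-batch holds only 2 data points because 10 is not a multiple of the batch size; passing `drop_last=True` to the `DataLoader` would discard it.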

Let's start by defining ImageNet-specific transforms.


In [9]:
from torchvision.transforms import RandomHorizontalFlip, RandomResizedCrop
from torchvision.transforms import Resize, CenterCrop
from torchvision.transforms import ToTensor, Normalize, Lambda
from torchvision.transforms import Compose


ImageNetStats = \
    {
        'normalise':
            {
                'mean': (0.485, 0.456, 0.406),
                'std':  (0.229, 0.224, 0.225)
            },
        'quantise':
            {
                'min':   -2.1179039478,  # computed on the normalised images of the validation partition
                'max':   2.6400001049,   # computed on the normalised images of the validation partition
                'scale': 0.020625000819563866
            }
    }


class ImageNetNormalise(Normalize):
    def __init__(self):
        super(ImageNetNormalise, self).__init__(**ImageNetStats['normalise'])
        

class ImageNetIntegerise(Lambda):
    def __init__(self):
        INT8_MIN = -2**(8-1)
        INT8_MAX = 2**(8-1) - 1
        image_scale = ImageNetStats['quantise']['scale']
        super(ImageNetIntegerise, self).__init__(lambda x: torch.clip((x / image_scale).floor(), INT8_MIN, INT8_MAX))


class ImageNetTransform(Compose):

    def __init__(self, partition_spec: DataSetPartitionSpec, image_size: int = 224, integerise: bool = False):

        # validate arguments
        RESIZE_SIZE = 256
        if not (image_size <= RESIZE_SIZE):
            raise ValueError  # otherwise, we can not crop the resized image to the desired size

        partition = DataSetPartition.canonicalise(partition_spec)

        if partition is DataSetPartition['TRAINING']:
            transforms = [RandomHorizontalFlip(),
                          RandomResizedCrop(image_size)]
        else:
            transforms = [Resize(RESIZE_SIZE),
                          CenterCrop(image_size)]

        transforms += [ToTensor(),
                       ImageNetNormalise(),
                       ImageNetIntegerise()]
        
        if not integerise:
            transforms += [Lambda(lambda x: x * ImageNetStats['quantise']['scale'])]  # return a fake-quantised `Tensor`

        super(ImageNetTransform, self).__init__(transforms)
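The integerisation and fake-quantisation steps of `ImageNetTransform` boil down to simple arithmetic: a normalised value `x` is mapped to the integer code `clip(floor(x / scale), -128, 127)`, and the fake-quantised value is that code multiplied back by `scale`. The sketch below reproduces these two steps on a handful of example values chosen for illustration, using the same scale as `ImageNetStats['quantise']['scale']`.

```python
import torch

SCALE = 0.020625000819563866                 # same value as ImageNetStats['quantise']['scale']
INT8_MIN, INT8_MAX = -2 ** 7, 2 ** 7 - 1

x = torch.tensor([1.0, 3.0, -3.0, 0.0, -0.5])                 # example normalised values
codes = torch.clip((x / SCALE).floor(), INT8_MIN, INT8_MAX)   # what `integerise=True` returns
x_fq = codes * SCALE                                          # what `integerise=False` returns

print(codes.tolist())
print(x_fq.tolist())
```

Values outside the representable range saturate (3.0 and -3.0 above). Since the transform rounds with `floor`, an in-range value is never smaller than its fake-quantised version, and the two differ by less than one quantum (`0 <= x - x_fq < scale`).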


We can use a [factory pattern](https://en.wikipedia.org/wiki/Factory_method_pattern) to create training and validation `DataLoader`s while reducing code duplication.

In [10]:
from collections import OrderedDict
import os
import torchvision


class ImageNetDataLoaderFactory(object):
    
    def __init__(self, path_data: str):
        
        if not os.path.isdir(path_data):
            raise FileNotFoundError(path_data)  # missing ImageNet data folder

        super(ImageNetDataLoaderFactory, self).__init__()

        self._partition_to_subfolder = OrderedDict([
            (DataSetPartition['TRAINING'], os.path.join(path_data, 'train')),
            (DataSetPartition['VALIDATION'], os.path.join(path_data, 'val')),
        ])

        
    def get_dataset(self,
                    partition_spec: DataSetPartitionSpec,
                    transform:      torchvision.transforms.Compose) -> torch.utils.data.Dataset:
        partition = DataSetPartition.canonicalise(partition_spec)
        return torchvision.datasets.ImageFolder(self._partition_to_subfolder[partition], transform)
    
    @staticmethod
    def get_sampler(partition_spec: DataSetPartitionSpec,
                    dataset: torch.utils.data.Dataset) -> torch.utils.data.Sampler:
        partition = DataSetPartition.canonicalise(partition_spec)
        return torch.utils.data.RandomSampler(dataset) if (partition is DataSetPartition['TRAINING']) else torch.utils.data.SequentialSampler(dataset)
    
    def get_dataloader(self,
                       partition_spec: DataSetPartitionSpec,
                       transform:      torchvision.transforms.Compose,
                       batch_size:     int,
                       num_workers:    int = 1) -> torch.utils.data.DataLoader:

        partition = DataSetPartition.canonicalise(partition_spec)

        dataset = self.get_dataset(partition, transform)
        sampler = ImageNetDataLoaderFactory.get_sampler(partition, dataset)
        loader = torch.utils.data.DataLoader(dataset=dataset,
                                             sampler=sampler,
                                             batch_size=batch_size,
                                             num_workers=num_workers)

        return loader


We are now ready to create our training and validation `DataLoader`s.

In [11]:
# create the `DataLoader` factory
path_data = os.path.join(os.curdir, 'data')
loader_factory = ImageNetDataLoaderFactory(path_data)

per_gpu_batch_size = 64
batch_size = max(1, n_gpus) * per_gpu_batch_size  # `nn.DataParallel` will split each mini-batch, dispatching `per_gpu_batch_size` items to each GPU

# create the training `DataLoader`
train_transform = ImageNetTransform('train')
train_loader = loader_factory.get_dataloader('train', train_transform, batch_size, num_workers=n_cpus)

# create the validation `DataLoader`
valid_transform = ImageNetTransform('valid')
valid_loader = loader_factory.get_dataloader('valid', valid_transform, batch_size, num_workers=n_cpus)



### Step 3: create a floating-point PyTorch network

PyTorch represents deep neural networks as `torch.nn.Module`s.

Since deep neural networks can be (and often are) modelled as function compositions, `Module`s have been designed to be composed into complex functions, i.e., complex networks.
For this reason, we can distinguish between:
* *atomic* `Module`s; examples are linear operations, batch normalisations, and activation operations;
* *container* `Module`s; containers are used to express layers (usually concatenating linear and activation operations, with possibly a batch normalisation in-between) and blocks of layers.

MobileNetV1 is a sequential composition of `Module`s.
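The distinction between atomic and container `Module`s can be made concrete with plain PyTorch. The sketch below uses a toy network of our own (not MobileNetV1): it builds a block out of two layers and lists the leaves of the hierarchy, i.e., the `Module`s without children.

```python
import torch
import torch.nn as nn


def make_layer(in_channels: int, out_channels: int, kernel_size: int) -> nn.Sequential:
    """A "layer": a container chaining a linear operation, a batch normalisation and an activation."""
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size, padding=kernel_size // 2, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
    )


# a "block": a container whose children are themselves containers
block = nn.Sequential(make_layer(3, 8, 3), make_layer(8, 16, 1))

# atomic Modules are the leaves of the hierarchy, i.e., the Modules without children
atoms = [(name, type(m).__name__) for name, m in block.named_modules()
         if len(list(m.children())) == 0]
for name, kind in atoms:
    print(name, kind)

y = block(torch.randn(2, 3, 16, 16))   # the composition is itself a function of its input
print(tuple(y.shape))
```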


In [12]:
import torch.nn as nn


CONFIGS = OrderedDict([
    ('STANDARD', [
        ( 2, 1),
        ( 4, 2),
        ( 4, 1),
        ( 8, 2),
        ( 8, 1),
        (16, 2),
        (16, 1),
        (16, 1),
        (16, 1),
        (16, 1),
        (16, 1),
        (32, 2),
        (32, 1)
    ])
])


ACTIVATIONS = ('relu', 'relu6',)


class MobileNetV1(nn.Module):

    def __init__(self,
                 config:     str,
                 capacity:   float = 1.0,
                 activation: str = 'ReLU',
                 n_classes:  int = 1000,
                 seed:       int = -1):

        # validate inputs
        config = config.upper()  # canonicalise
        if config not in CONFIGS.keys():
            raise ValueError  # invalid configuration
            
        if not (0.0 < capacity <= 1.0):
            raise ValueError  # capacity must be a positive, compressive (i.e., not greater than one) scaling factor

        activation = activation.lower()  # canonicalise
        if activation not in ACTIVATIONS:
            raise ValueError  # invalid activation function
        if activation == 'relu':
            activation_class = nn.ReLU
        else:  # activation == 'relu6':
            activation_class = nn.ReLU6

        super(MobileNetV1, self).__init__()

        # build the network
        base_width      = int(32 * capacity)
        self.pilot      = MobileNetV1.make_pilot(base_width, activation_class)
        self.features   = MobileNetV1.make_features(config, base_width, activation_class)
        self.avgpool    = MobileNetV1.make_avgpool()
        self.classifier = MobileNetV1.make_classifier(config, base_width, n_classes)

        self._initialize_weights(seed)

    @staticmethod
    def make_standard_convolution_layer(in_channels:      int,
                                        out_channels:     int,
                                        stride:           Union[int, Tuple[int, ...]],
                                        activation_class: type) -> nn.Sequential:

        modules = []

        modules += [nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), stride=stride, padding=(1, 1), bias=False)]
        modules += [nn.BatchNorm2d(out_channels)]
        modules += [activation_class(inplace=True)]

        return nn.Sequential(*modules)

    @staticmethod
    def make_depthwise_separable_convolution_block(in_channels:      int,
                                                   out_channels:     int,
                                                   stride:           Union[int, Tuple[int, ...]],
                                                   activation_class: type) -> nn.Sequential:

        modules = []

        # depthwise
        modules += [nn.Conv2d(in_channels, in_channels, kernel_size=(3, 3), stride=stride, padding=(1, 1), groups=in_channels, bias=False)]
        modules += [nn.BatchNorm2d(in_channels)]
        modules += [activation_class(inplace=True)]

        # pointwise
        modules += [nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=False)]
        modules += [nn.BatchNorm2d(out_channels)]
        modules += [activation_class(inplace=True)]

        return nn.Sequential(*modules)

    @staticmethod
    def make_pilot(base_width:       int,
                   activation_class: type) -> nn.Sequential:

        in_channels = 3
        out_channels = base_width

        return MobileNetV1.make_standard_convolution_layer(in_channels=in_channels,
                                                           out_channels=out_channels,
                                                           stride=2,  # we start with a spatial down-sampling
                                                           activation_class=activation_class)

    @staticmethod
    def make_features(config:           str,
                      base_width:       int,
                      activation_class: type) -> nn.Sequential:

        modules = []

        in_channels = base_width
        for n_channels_multiplier, stride in CONFIGS[config]:
            out_channels = base_width * n_channels_multiplier
            modules += [MobileNetV1.make_depthwise_separable_convolution_block(in_channels=in_channels,
                                                                               out_channels=out_channels,
                                                                               stride=stride,
                                                                               activation_class=activation_class)]
            in_channels = out_channels

        return nn.Sequential(*modules)

    @staticmethod
    def make_avgpool() -> nn.AdaptiveAvgPool2d:
        return nn.AdaptiveAvgPool2d((1, 1))

    @staticmethod
    def make_classifier(config:     str,
                        base_width: int,
                        n_classes:  int) -> nn.Linear:

        last_n_channels_multiplier = CONFIGS[config][-1][0]
        in_channels = last_n_channels_multiplier * base_width
        in_features = in_channels * 1 * 1

        return nn.Linear(in_features=in_features, out_features=n_classes)

    def forward(self, x):

        x = self.pilot(x)
        x = self.features(x)
        x = self.avgpool(x)

        x = x.view(x.size(0), -1)

        x = self.classifier(x)

        return x

    def _initialize_weights(self, seed: int = -1):

        if seed >= 0:
            torch.manual_seed(seed)

        for m in self.modules():

            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
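As a sanity check on `CONFIGS['STANDARD']`, the following sketch traces how channel widths and spatial resolutions evolve through the network for a 224x224 input. The helper `trace_shapes` is hypothetical; it only applies the standard output-size formula `(size + 2 * padding - kernel_size) // stride + 1` to the 3x3 convolutions that carry the strides (the 1x1 pointwise convolutions preserve the resolution).

```python
# copy of the (multiplier, stride) pairs of CONFIGS['STANDARD']
STANDARD = [(2, 1), (4, 2), (4, 1), (8, 2), (8, 1), (16, 2),
            (16, 1), (16, 1), (16, 1), (16, 1), (16, 1), (32, 2), (32, 1)]


def trace_shapes(capacity: float, image_size: int = 224):
    """Hypothetical helper: (name, channels, spatial size) after the pilot and each block."""
    base_width = int(32 * capacity)
    # pilot: 3x3 standard convolution, stride 2, padding 1
    size = (image_size + 2 * 1 - 3) // 2 + 1
    shapes = [("pilot", base_width, size)]
    for i, (multiplier, stride) in enumerate(STANDARD):
        # the 3x3 depthwise convolution (padding 1) carries the stride
        size = (size + 2 * 1 - 3) // stride + 1
        shapes.append((f"features.{i}", base_width * multiplier, size))
    return shapes


shapes = trace_shapes(0.75)
for name, channels, size in shapes:
    print(f"{name:12s} {channels:5d} x {size:3d} x {size:3d}")
```

With `capacity=0.75`, the last block outputs 768 channels at a 7x7 resolution, which `nn.AdaptiveAvgPool2d((1, 1))` reduces to the 768 input features of the classifier.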


Let's create our MobileNetV1.

In [13]:
# create a MobileNetV1
config = 'standard'
capacity = 0.75
activation = 'relu'  # other option: `relu6`

mnv1 = MobileNetV1(config=config, capacity=capacity, activation=activation)


In the introduction of this sub-section, we observed that PyTorch networks are modelled as hierarchies of `Module`s.
QuantLib provides a convenient `lightweight` sub-package to traverse these hierarchies and list their atoms (i.e., non-container `Module`s).


In [14]:
import quantlib.editing.graphs as qg

mnv1_lw = qg.lw.quantlib_traverse(mnv1)
mnv1_lw.show()



      0.pilot	Conv2d(3, 24, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      1.pilot	BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      2.pilot	ReLU(inplace=True)
 0.0.features	Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=24, bias=False)
 1.0.features	BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
 2.0.features	ReLU(inplace=True)
 3.0.features	Conv2d(24, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
 4.0.features	BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
 5.0.features	ReLU(inplace=True)
 0.1.features	Conv2d(48, 48, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=48, bias=False)
 1.1.features	BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
 2.1.features	ReLU(inplace=True)
 3.1.features	Conv2d(48, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
 4.1.features	BatchNorm2d(96, eps=1e-05, m

A fundamental functionality of PyTorch is its convenient interface for using GPUs.
Before we explain how this interface works, we need to provide some additional information about the workings of PyTorch `Tensor`s.

Each PyTorch `Tensor` is a wrapper abstraction around a *payload* array.
Apart from this payload, a `Tensor` has other attributes and methods, most of which operate on the payload (we call these *payload methods*).
One of these attributes is `device`, which indicates whether the payload is stored on the main memory of the computing system (i.e., the one managed directly by the CPU) or on the memory of some accelerator attached to the computing system (e.g., a GPU). The C++ backend of `Tensor`s includes several implementations of payload methods, one for each computing device that might be available on our computing system (e.g., CPU vs. GPU).
Depending on the `device` attributes of the `Tensor`s involved in a given operation, PyTorch's runtime engine dispatches the code to the correct version of the payload method.

Most `Module`s have parameters (and some also have buffers, such as the running statistics of batch normalisations), but moving these `Tensor`s individually and manually to the correct device memory is cumbersome and error-prone.
Thus, PyTorch `Module`s expose a `to` method that automatically moves all their parameters and buffers to the desired device memory.
Due to the hierarchical nature of `Module` compositions, calling `to` on a container `Module` will recursively invoke the same call on its child `Module`s.

If more than one GPU is available on our system, it is possible to wrap a `Module` into an `nn.DataParallel` object.
At runtime, this object will automatically replicate the `Module` and map a replica to each available GPU; then, it will partition mini-batches to distribute data points evenly amongst all the available GPUs.
`DataParallel` objects are a convenient abstraction to exploit all the computational power available on your computing system.

However, note that QuantLib's editing functionalities should be applied to the network itself.
Although `nn.DataParallel` is a subclass of `nn.Module`, it is a wrapper that stores the original network in its `module` attribute; hence, we prefer to keep at least one symbolic handle to the underlying network object.
In this way, we will be able to edit it directly.
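The sketch below illustrates both points on a toy network of our own. A single `to` call on the container converts every floating-point parameter and buffer of its children (we also change the `dtype` so that the effect is visible on CPU-only systems), and `nn.DataParallel` is itself a `Module` that keeps the original network in its `module` attribute. On a system without GPUs, `DataParallel` simply forwards calls to the wrapped network.

```python
import torch
import torch.nn as nn

net = nn.Sequential(nn.Conv2d(3, 4, kernel_size=3, bias=False), nn.BatchNorm2d(4), nn.ReLU())

# similar device choice as in the notebook: the current GPU if any, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device("cuda", torch.cuda.current_device())
else:
    device = torch.device("cpu")

# a single call on the container reaches the parameters and buffers of all its children
net = net.to(device=device, dtype=torch.float64)
moved = list(net.parameters()) + [b for b in net.buffers() if b.is_floating_point()]
print({(str(t.device), t.dtype) for t in moved})

# `DataParallel` is itself a `Module`, but the original network lives in its `module` attribute
dp_net = nn.DataParallel(net)
print(isinstance(dp_net, nn.Module), dp_net.module is net)
```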


In [15]:
def maybe_migrate_to_gpu(network: nn.Module,
                         device:  torch.device,
                         n_gpus:  int) -> Tuple[nn.Module, Union[nn.Module, nn.DataParallel]]:
    """If GPUs are available, migrate the network there to speed up computations."""

    if n_gpus > 0:
        network = network.to(device=device)  # move the model parameters to the lead GPU
    
    if n_gpus > 1:
        maybenndp_network = nn.DataParallel(network)  # at runtime, the model will be replicated on each available GPU
    else:
        maybenndp_network = network
    
    return network, maybenndp_network


Let's migrate our MobileNetV1 to the GPU(s), if any are available.

In [16]:
mnv1, maybenndp_mnv1 = maybe_migrate_to_gpu(mnv1, device, n_gpus)

### Step 4: evaluate the performance of a raw network

As in most experimental sciences, performing elementary experiments is crucial in machine learning too.
These experiments are important to:
* validate assumptions and expectations that, if violated, could invalidate all the following experiments;
* establish baselines against which we can compare future results.

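Since the 1000 classes are equally represented in the ImageNet validation set (50 images each), and the predictions of an untrained network are essentially unrelated to the true labels, we expect the accuracy of an untrained MobileNetV1 to be around 1/1000, i.e., 0.1%; this holds even if the network always predicts the same class.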
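The expected 0.1% can be checked with a quick simulation on labels distributed like those of the ImageNet validation set (1000 classes, 50 images each). The two "untrained networks" below are idealised stand-ins of our own: one always predicts the same class, the other predicts uniformly at random.

```python
import random

N_CLASSES, IMAGES_PER_CLASS = 1000, 50   # ImageNet validation set: 50 images per class
labels = [c for c in range(N_CLASSES) for _ in range(IMAGES_PER_CLASS)]


def accuracy(predicted, true) -> float:
    return 100.0 * sum(p == t for p, t in zip(predicted, true)) / len(true)


# degenerate untrained network: always predicts the same class
constant_acc = accuracy([0] * len(labels), labels)

# untrained network whose predictions are uniform at random
rng = random.Random(0)
random_acc = accuracy([rng.randrange(N_CLASSES) for _ in labels], labels)

print(f"constant predictor: {constant_acc:.3f}%")
print(f"random predictor:   {random_acc:.3f}%")
```

Thanks to the class balance, the constant predictor scores exactly 0.1%, while the random one fluctuates around that value.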


In [17]:
class Label(NamedTuple):
    true:      int
    predicted: int


class Evaluation(OrderedDict):

    def __setitem__(self, input_id: int, label: Label):
        if not isinstance(input_id, int):
            raise TypeError
        if not isinstance(label, Label):
            raise TypeError

        super(Evaluation, self).__setitem__(input_id, label)

    @property
    def correct(self) -> int:
        return sum((label.true == label.predicted) for label in self.values())

    @property
    def accuracy(self) -> float:
        return 100.0 * (float(self.correct) / len(self))

    def compare(self, other: Evaluation) -> float:
        """Return the percentage of matching predictions."""

        if len(set(self.keys()).symmetric_difference(set(other.keys()))) > 0:
            raise ValueError  # can only compare evaluations carried out on the same data points

        # otherwise, proceed with the comparison
        matched: int = 0
        for input_id, label in self.items():
            other_label = other[input_id]
            if label.predicted == other_label.predicted:
                matched += 1

        return 100.0 * (float(matched) / len(self))

    
def evaluate_network(loader:  torch.utils.data.DataLoader,
                     network: Union[nn.Module, nn.DataParallel],
                     device:  torch.device) -> Evaluation:

    if not isinstance(loader.sampler, torch.utils.data.SequentialSampler):
        raise ValueError  # the order of the data points is not deterministic, and the input IDs lose their meaning
    
    evaluation = Evaluation()
    base_input_id: int = 0

    with torch.no_grad():  # no gradients are needed for evaluation: save memory and time
        for x, y_true in loader:

            x = x.to(device=device)
            y_true = y_true.to(device=device)

            y_pred = torch.argmax(network(x), dim=1)

            for i, (yt, yp) in enumerate(zip(y_true.flatten(), y_pred.flatten())):
                evaluation[base_input_id + i] = Label(int(yt), int(yp))
            base_input_id += len(x)

    return evaluation


Let's evaluate the network performance.

In [18]:
# set the network in evaluation mode: batch normalisations use their running statistics instead of the mini-batch ones
maybenndp_mnv1.eval()

# evaluate the network on the validation set
mnv1_perf = evaluate_network(valid_loader, maybenndp_mnv1, device)
print("Accuracy (floating-point, untrained): {:6.2f}%.".format(mnv1_perf.accuracy))

# restore the training mode
maybenndp_mnv1.train()
pass



The accuracy is in line with our expectations: a positive sanity check.

## Part 2: training and evaluating a fake-quantised network

### Step 1: retrieve a pre-trained floating-point network

In some cases, we can apply quantisation algorithms to pre-trained floating-point networks.
These algorithms can be classified into:
* **post-training quantisation (PTQ)** algorithms, which do not need to run any gradient descent iteration or apply any parameter updates;
* **quantisation-aware fine-tuning (QAFT)** algorithms, which are applications of *quantisation-aware training (QAT)* algorithms lasting at most a few epochs.
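To make the PTQ idea concrete, the sketch below quantises a weight tensor to signed 8-bit integers without any gradient descent iteration: it derives a per-tensor scale from the largest weight magnitude (symmetric min-max calibration, one simple choice among many) and rounds. The helper `quantise_weights_int8` and the random stand-in weights are ours; we do not claim this matches the quantisers used by QuantLib.

```python
import torch


def quantise_weights_int8(w: torch.Tensor):
    """Hypothetical PTQ helper: symmetric, per-tensor, signed 8-bit quantisation."""
    scale = w.abs().max() / 127                                  # largest magnitude -> 127
    codes = torch.clamp(torch.round(w / scale), -128, 127).to(torch.int8)
    return codes, scale


torch.manual_seed(0)
w = torch.randn(48, 24, 1, 1) * 0.1          # stand-in for a trained pointwise kernel
codes, scale = quantise_weights_int8(w)
w_fq = codes.float() * scale                 # fake-quantised weights

max_err = (w - w_fq).abs().max().item()
print(codes.dtype, codes.min().item(), codes.max().item(), max_err, scale.item() / 2)
```

Rounding to the nearest code bounds the error of every weight by half a quantum, `scale / 2`.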

We create a new MobileNetV1 and load the parameters of a pre-trained floating-point model into it, to speed up our work.


In [19]:
# create a MobileNetV1
config = 'standard'
capacity = 0.75
activation = 'relu'  # other option: `relu6`

mnv1 = MobileNetV1(config=config, capacity=capacity, activation=activation)
mnv1, maybenndp_mnv1 = maybe_migrate_to_gpu(mnv1, device, n_gpus)

# get the path to the floating-point checkpoint
path_logs = os.path.join(os.curdir, 'logs')
fp_checkpoint_filename = '_'.join(['MNv1', str(capacity), activation]) + '.ckpt'
path_fp_checkpoint = os.path.join(path_logs, fp_checkpoint_filename)
if not os.path.isfile(path_fp_checkpoint):
    raise FileNotFoundError

# load the pre-trained parameters into the network object
pretrained_state_dict = torch.load(path_fp_checkpoint, map_location=device)  # map the stored tensors onto the current device
mnv1.load_state_dict(pretrained_state_dict)



Let's evaluate the network performance.

In [20]:
# set the network in evaluation mode: batch normalisations use their running statistics instead of the mini-batch ones
maybenndp_mnv1.eval()

# evaluate the network on the validation set
mnv1_perf = evaluate_network(valid_loader, maybenndp_mnv1, device)
print("Accuracy (floating-point, trained): {:6.2f}%.".format(mnv1_perf.accuracy))

# restore the training mode
maybenndp_mnv1.train()
pass



The accuracy is in line with the one reported in the [original paper](https://arxiv.org/pdf/1704.04861.pdf) for a MobileNetV1 whose capacity (width multiplier) is set to 0.75.

### Step 2: perform the float-to-fake (F2F) conversion

QuantLib's `editing` package implements the building blocks of a rudimentary compiler to transform floating-point `Module`s into quantised neural networks.
This package consists of two sub-packages:
* `graphs`, extending PyTorch's `nn` and `fx` namespaces;
* `editing`, implementing the computational graph annotation and rewriting functionalities.

Quantising our MobileNetV1 requires replacing its composing `Module`s with counterparts that support quantisation.
In QuantLib, these counterparts are `_QModule`s.

As a first step, we need to trace our floating-point network, i.e., capture its computational graph.


In [21]:
# trace
mnv1.eval()  # remember to freeze parameters, since the `Editor`s might operate with them
mnv1fp = qg.fx.quantlib_symbolic_trace(root=mnv1)
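For readers new to tracing, the sketch below shows what an `fx.GraphModule` looks like, using the plain `torch.fx.symbolic_trace` tracer from PyTorch on a toy network of our own; we make no claim about how `quantlib_symbolic_trace` differs from it. Each node of the traced graph records one operation, and the `GraphModule` computes the same function as the original network.

```python
import torch
import torch.fx as fx
import torch.nn as nn


class TinyNet(nn.Module):
    """Toy network of our own, used only to illustrate tracing."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(8)
        self.relu = nn.ReLU()
        self.fc = nn.Linear(8, 10)

    def forward(self, x):
        x = self.relu(self.bn(self.conv(x)))
        x = x.mean(dim=(2, 3))   # global average pooling
        return self.fc(x)


net = TinyNet().eval()        # evaluation mode: batch normalisation uses its running statistics
gm = fx.symbolic_trace(net)   # an `fx.GraphModule`

# each node records one operation: placeholder (input), call_module, call_method, output, ...
for node in gm.graph.nodes:
    print(node.op, node.target)

x = torch.randn(2, 3, 8, 8)
print(torch.allclose(gm(x), net(x)))   # the traced module computes the same function
```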


Now that we have an `fx.GraphModule` object, we can pass it to the tool performing the so-called **float-to-fake (F2F)** conversion flow.

In this case, we aim for a quantised network that uses signed 8-bit integers for the weight arrays and unsigned 8-bit integers for the feature arrays.
The chosen QAT algorithm is [*parametrised clipping activation (PACT)*](https://proceedings.mlsys.org/paper/2019/file/006f52e9102a8d3be2fe5614f42ba989-Supplemental.pdf).
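The core of PACT fits in a few lines: activations are clipped to `[0, alpha]`, where the clipping bound `alpha` is a learnable parameter, and the clipped values are fake-quantised to `2**n_bits - 1` uniform steps; a straight-through estimator lets gradients reach `alpha` despite the non-differentiable rounding. The `PACTReLUSketch` module below is our own illustrative sketch of this technique, not QuantLib's `PACTReLU`, and it omits refinements such as the regularisation of `alpha`.

```python
import torch
import torch.nn as nn


class PACTReLUSketch(nn.Module):
    """Illustrative PACT activation: clip to [0, alpha], then fake-quantise to n_bits unsigned levels."""

    def __init__(self, n_bits: int = 8, alpha_init: float = 6.0):
        super().__init__()
        self.n_steps = 2 ** n_bits - 1                       # 255 steps for 8 bits
        self.alpha = nn.Parameter(torch.tensor(alpha_init))  # learnable clipping bound

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = torch.minimum(torch.relu(x), self.alpha)         # clip(x, 0, alpha)
        scale = self.alpha / self.n_steps                    # quantum
        y_q = torch.round(y / scale) * scale                 # fake-quantised value
        # straight-through estimator: the forward pass returns y_q, the backward pass uses y
        return y + (y_q - y).detach()


act = PACTReLUSketch(n_bits=8, alpha_init=2.0)
x = torch.linspace(-3.0, 3.0, steps=1001)
out = act(x)
out.sum().backward()

print(out.min().item(), out.max().item())   # outputs lie in [0, alpha]
print(act.alpha.grad.item())                # each input above alpha contributes 1
```

Only the inputs that are clipped contribute to the gradient of `alpha`, which is how PACT learns where to place the clipping bound.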


In [22]:
import quantlib.editing.editing as qe

f2fconverter = qe.f2f.F2F8bitPACTConverter()
mnv1fq_uninit = f2fconverter(mnv1fp)


Let's visually inspect whether the conversion was successful.

In [23]:
mnv1fq_uninit_lw = qg.lw.quantlib_traverse(root=mnv1fq_uninit)
mnv1fq_uninit_lw.show()



      0.pilot	PACTConv2d(3, 24, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      1.pilot	BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      2.pilot	PACTReLU(inplace=True)
 0.0.features	PACTConv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=24, bias=False)
 1.0.features	BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
 2.0.features	PACTReLU(inplace=True)
 3.0.features	PACTConv2d(24, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
 4.0.features	BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
 5.0.features	PACTReLU(inplace=True)
 0.1.features	PACTConv2d(48, 48, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=48, bias=False)
 1.1.features	BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
 2.1.features	PACTReLU(inplace=True)
 3.1.features	PACTConv2d(48, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
 4.1.f

Now we are going to prepare all the ingredients to train and validate our fake-quantised network:
* the loss function; in PyTorch, loss functions are implemented as `Module`s;
* the optimiser; in PyTorch, optimisers are implemented as `torch.optim.Optimizer` objects; the responsibility of `Optimizer`s is updating the parameters using their gradients, **not** computing these gradients (which is a prerogative of the [*autograd* engine](https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html)).

Some QAT algorithms, such as PACT, might require specific `Optimizer`s.
In these cases, the corresponding QuantLib `algorithms` sub-package provides their implementation.
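The split of responsibilities between autograd and `Optimizer`s is visible in a single training step: `loss.backward()` only fills the `.grad` fields, while `optimiser.step()` changes the parameter values. The sketch also shows *parameter groups*, the stock PyTorch mechanism to give some parameters (here, a hypothetical clipping bound `alpha`) their own hyper-parameters, such as a dedicated weight decay. This is only one plausible way to treat clipping bounds separately; we do not claim it reflects how `PACTSGD` is implemented.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 2)
alpha = nn.Parameter(torch.tensor(6.0))   # hypothetical clipping bound of an activation

# parameter groups: only the clipping bound is subject to weight decay
optimiser = torch.optim.SGD(
    [
        {"params": model.parameters()},
        {"params": [alpha], "weight_decay": 0.001},
    ],
    lr=0.001,
    momentum=0.9,
)

x, y_true = torch.randn(8, 4), torch.randint(0, 2, (8,))
loss = nn.CrossEntropyLoss()(model(torch.minimum(torch.relu(x), alpha)), y_true)

w_before = model.weight.detach().clone()
optimiser.zero_grad()            # clear old gradients
loss.backward()                  # autograd: computes `.grad` for every parameter
unchanged_after_backward = torch.equal(w_before, model.weight.detach())
optimiser.step()                 # optimiser: updates the values in place using `.grad`
changed_after_step = not torch.equal(w_before, model.weight.detach())

print(unchanged_after_backward, changed_after_step, alpha.item())
```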


In [24]:
# if GPUs are available, migrate the network
mnv1fq_uninit, maybenndp_mnv1fq_uninit = maybe_migrate_to_gpu(mnv1fq_uninit, device, n_gpus)

# create the loss function
loss_fn = nn.CrossEntropyLoss()

# create the optimiser; since we use PACT, we need an `Optimizer` capable of updating the clipping bounds independently of other parameters
import quantlib.algorithms as qa
optimiser = qa.qalgorithms.qatalgorithms.pact.PACTSGD(mnv1fq_uninit, pact_decay=0.001, lr=0.001, momentum=0.9)


Several passes (epochs) through the training set are usually required to bring the model to convergence.
Since we also want to measure the performance of our model on the validation set after each training epoch, we define convenient functions to run individual training and validation epochs.


In [25]:
def train_one_epoch(loader:    torch.utils.data.DataLoader,
                    network:   Union[nn.Module, nn.DataParallel],
                    device:    torch.device,
                    loss_fn:   nn.Module,
                    optimiser: torch.optim.Optimizer,
                    verbose:   bool = False):
    
    network.train()
    
    # statistical performance counters  # TODO: define a `StatisticalPerformanceCounters` object
    n_points:   int = 0
    correct:    int = 0
    total_loss: float = 0.0
    
    for batch_id, (x, y_true) in enumerate(loader):
        
        # cast data points to the network's device
        x = x.to(device)
        y_true = y_true.to(device)

        # forward pass
        y_pred = network(x)
        loss = loss_fn(y_pred, y_true)
        
        # update performance counters
        n_points += len(x)
        correct += int(torch.sum(y_true == y_pred.argmax(dim=1)))
        total_loss = total_loss + (loss.item() * len(x))
        if verbose:
            print("Training batch [{:5d}/{:5d}] - Loss: {:8.3f} - Accuracy: {:6.2f}%".format(batch_id, len(loader), total_loss / n_points, 100.0 * (float(correct) / n_points)))
        
        # backward pass
        optimiser.zero_grad()  # clear old gradients
        loss.backward()        # compute new gradients
        optimiser.step()       # apply gradient descent step
        
        
def validate_one_epoch(loader:    torch.utils.data.DataLoader,
                       network:   Union[nn.Module, nn.DataParallel],
                       device:    torch.device,
                       loss_fn:   nn.Module,
                       verbose:   bool = False):
    
    network.eval()

    # statistical performance counters  # TODO: define a `StatisticalPerformanceCounters` object
    n_points:   int = 0
    correct:    int = 0
    total_loss: float = 0.0
    
    # gradients are not needed for validation: disabling their tracking saves memory and time
    with torch.no_grad():
        for batch_id, (x, y_true) in enumerate(loader):

            # cast data points to the network's device
            x = x.to(device)
            y_true = y_true.to(device)

            # forward pass
            y_pred = network(x)
            loss = loss_fn(y_pred, y_true)

            # update performance counters
            n_points += len(x)
            correct += int(torch.sum(y_true == y_pred.argmax(dim=1)))
            total_loss = total_loss + (loss.item() * len(x))
            if verbose:
                print("Validation batch [{:5d}/{:5d}] - Loss: {:8.3f} - Accuracy: {:6.2f}%".format(batch_id, len(loader), total_loss / n_points, 100.0 * (float(correct) / n_points)))
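
Both functions carry a TODO asking for a `StatisticalPerformanceCounters` object. The following is a minimal sketch of what such an object could look like: the class name comes from the TODO, but its fields, the `update` method, and the `loss`/`accuracy` properties are our own assumptions, not part of QuantLib. It accumulates the same three quantities that the loops above track by hand.

```python
from dataclasses import dataclass


@dataclass
class StatisticalPerformanceCounters:
    """Hypothetical accumulator for the per-epoch statistics tracked above."""
    n_points: int = 0
    correct: int = 0
    total_loss: float = 0.0

    def update(self, batch_size: int, batch_correct: int, batch_mean_loss: float) -> None:
        # `batch_mean_loss` is the mean loss over the batch (e.g., `loss.item()`),
        # so we weight it by the batch size to accumulate the summed loss
        self.n_points += batch_size
        self.correct += batch_correct
        self.total_loss += batch_mean_loss * batch_size

    @property
    def loss(self) -> float:
        # average loss per data point seen so far
        return self.total_loss / self.n_points if self.n_points > 0 else 0.0

    @property
    def accuracy(self) -> float:
        # percentage of correctly classified data points seen so far
        return 100.0 * self.correct / self.n_points if self.n_points > 0 else 0.0


counters = StatisticalPerformanceCounters()
counters.update(batch_size=4, batch_correct=3, batch_mean_loss=0.5)
counters.update(batch_size=4, batch_correct=1, batch_mean_loss=1.5)
print(f"Loss: {counters.loss:8.3f} - Accuracy: {counters.accuracy:6.2f}%")
```

Inside the loops, a single call such as `counters.update(len(x), int(torch.sum(y_true == y_pred.argmax(dim=1))), loss.item())` would replace the three separate counter updates, and the functions could return `counters` to their callers.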


Before training a fake-quantised network, we want to initialise the hyper-parameters of its quantisers so as to minimise the discrepancy between each floating-point array and its fake-quantised counterpart.

To do so, we *observe* the statistics of the `Tensor`s passing through the floating-point network during a validation epoch.
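
Conceptually, observing means recording running statistics of each quantiser's input across many batches and then deriving the quantiser's hyper-parameters from them. PACT clips activations to a learnable range `[0, alpha]`; the sketch below initialises `alpha` with the largest value observed and spreads `[0, alpha]` over the UINT8 levels. This is a generic strategy chosen for illustration: the class `RunningMaxObserver` and its methods are hypothetical, and QuantLib's PACT modules may rely on different statistics.

```python
from typing import Iterable, Tuple


class RunningMaxObserver:
    """Hypothetical observer: tracks the largest activation value seen across batches."""

    def __init__(self) -> None:
        self.max_val = 0.0  # PACT-clipped activations are non-negative

    def observe(self, batch: Iterable[float]) -> None:
        # update the running maximum with the values of one batch
        self.max_val = max(self.max_val, max(batch))

    def init_pact_clip(self, n_bits: int = 8) -> Tuple[float, float]:
        # use the observed maximum as the initial clipping bound `alpha`,
        # and spread [0, alpha] uniformly over the 2**n_bits unsigned levels
        alpha = self.max_val
        scale = alpha / (2 ** n_bits - 1)
        return alpha, scale


observer = RunningMaxObserver()
for batch in ([0.0, 0.3, 1.2], [0.5, 2.55, 0.1]):  # two toy "batches" of post-ReLU activations
    observer.observe(batch)

alpha, scale = observer.init_pact_clip()
print(f"alpha = {alpha}, scale = {scale:.4f}")
```

A plain maximum is sensitive to outliers; percentile-based or mean/standard-deviation-based statistics are common alternatives when a few extreme activations would otherwise waste most of the quantisation levels.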


In [26]:
# since we run this "observation" on CPU or a single GPU, we need to limit the batch size
warmup_valid_loader = loader_factory.get_dataloader('valid', valid_transform, per_gpu_batch_size, num_workers=n_cpus)

# set validation state
mnv1fq_uninit.eval()

# the PACT module classes whose quantisers we want to initialise
pact_module_types = tuple(qa.qalgorithms.qatalgorithms.pact.NNMODULE_TO_PACTMODULE.values())

# collect statistics about the floating-point `Tensor`s passing through the quantisers, so that we can better fit the quantisers' hyper-parameters
# start observing
for m in mnv1fq_uninit.modules():
    if isinstance(m, pact_module_types):
        m.start_observing()
# collect statistics
validate_one_epoch(warmup_valid_loader, mnv1fq_uninit, device, loss_fn)
# stop observing
for m in mnv1fq_uninit.modules():
    if isinstance(m, pact_module_types):
        m.stop_observing()

# restore training state
mnv1fq_uninit.train()

mnv1fq_init, maybenndp_mnv1fq_init = mnv1fq_uninit, maybenndp_mnv1fq_uninit  # now the quantisers' hyper-parameters are initialised


Note that this tuning operates locally (i.e., on each pair of corresponding floating-point and fake-quantised arrays, independently of the others).
Therefore, the residual discrepancies will likely accumulate as they propagate through the network, yielding poor accuracy.


In [27]:
# set the network in evaluation mode to "freeze" the parameters of batch normalisations
maybenndp_mnv1fq_init.eval()

# evaluate the network on the validation set
mnv1fq_perf = evaluate_network(valid_loader, maybenndp_mnv1fq_init, device)
print("Accuracy (fake-quantised, untrained): {:6.2f}%.".format(mnv1fq_perf.accuracy))

# restore the training mode
maybenndp_mnv1fq_init.train()
pass


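Let's proceed with a fine-tuning epoch.
To avoid retraining every time the notebook runs, we store the fine-tuned parameters in a checkpoint and simply reload them if the checkpoint already exists.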

In [28]:
fp_checkpoint_base_filename, extension = fp_checkpoint_filename.rsplit('.', 1)
fq_checkpoint_base_filename = '_'.join([fp_checkpoint_base_filename, 'FQ', 'uint8x', 'int8w'])
fq_checkpoint_filename = '.'.join([fq_checkpoint_base_filename, extension])

path_fq_checkpoint = os.path.join(path_logs, fq_checkpoint_filename)

if not os.path.isfile(path_fq_checkpoint):
    train_one_epoch(train_loader, maybenndp_mnv1fq_init, device, loss_fn, optimiser, verbose=True)
    torch.save(mnv1fq_init.state_dict(), path_fq_checkpoint)

mnv1fq_init.load_state_dict(torch.load(path_fq_checkpoint, map_location=device))  # `map_location` allows loading GPU checkpoints on CPU-only machines


What is the performance of our fake-quantised MobileNetV1?

In [24]:
# set the network in evaluation mode to "freeze" the parameters of batch normalisations
maybenndp_mnv1fq_init.eval()

# evaluate the network on the validation set
mnv1fq_init_perf = evaluate_network(valid_loader, maybenndp_mnv1fq_init, device)
print("Accuracy (fake-quantised, trained): {:6.2f}%.".format(mnv1fq_init_perf.accuracy))

# restore the training mode
maybenndp_mnv1fq_init.train()
pass


Accuracy (fake-quantised, trained):  68.54%.


Thanks to the careful initialisation of the quantisers' hyper-parameters, even a single fine-tuning epoch can suffice to recover most of the accuracy lost during F2F conversion.

## Part 3: integerising a fake-quantised network

A trained fake-quantised network is not a true integerised program.
To obtain such a network, we must apply the so-called **fake-to-true (F2T)** conversion.
An F2T conversion is a sequence of program transformations to turn a fake-quantised network into an integerised one.
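
The key identity behind F2T conversion can be checked on a toy dot product. A fake-quantised layer computes in floating point, but on values lying on integer grids: `x ≈ eps_x * q_x` and `w ≈ eps_w * q_w`. Hence `sum(w * x) = eps_w * eps_x * sum(q_w * q_x)`, so an integerised program can accumulate integer products and defer all the scaling to a single factor. The scales and values below are made-up illustrative numbers, and the helper `quantise` is our own, not a QuantLib function.

```python
def quantise(value: float, scale: float, q_min: int, q_max: int) -> int:
    # map a real value to the nearest point of an integer grid, saturating at its bounds
    return max(q_min, min(q_max, round(value / scale)))


eps_x, eps_w = 0.02, 0.005   # hypothetical activation (UINT8) and weight (INT8) scales
x = [0.10, 1.36, 2.50, 0.04]
w = [-0.31, 0.12, 0.05, 0.40]

q_x = [quantise(v, eps_x, 0, 255) for v in x]       # UINT8 activations
q_w = [quantise(v, eps_w, -128, 127) for v in w]    # INT8 weights

# fake-quantised: floating-point arithmetic on de-quantised values
y_fake = sum((eps_w * qw) * (eps_x * qx) for qw, qx in zip(q_w, q_x))

# true-quantised: integer accumulation followed by a single rescaling
acc = sum(qw * qx for qw, qx in zip(q_w, q_x))
y_true = eps_w * eps_x * acc

print(q_x, q_w, acc, y_fake, y_true)
```

The two results agree up to floating-point rounding; in a real network, the final rescaling is typically merged with the following operations (e.g., batch normalisation) into a requantisation step.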

Currently, QuantLib does not support fully automatic integerisation of arbitrary programs, since some network-specific computation patterns have not yet been broken down into atomic transforms.
These patterns usually appear in the most downstream parts of a network: network *backbones* are general-purpose, whereas output *heads* are task-specific and can introduce exotic computation patterns.

Before we proceed with F2T conversion, we need to specify how to rewrite the classifier of our MobileNetV1.


In [25]:
class MNv1Head(nn.Module):

    def __init__(self):
        super(MNv1Head, self).__init__()
        self.eps = qg.nn.EpsTunnel(torch.Tensor([1.0]))
        self.avg = nn.AdaptiveAvgPool2d((1, 1))
        self.lin = nn.Linear(1, 1, bias=True)

    def forward(self, x):
        x = self.eps(x)
        x = self.avg(x)
        x = x.view(x.size(0), 1)
        x = self.lin(x)
        return x


class MNv1HeadApplier(qe.editors.nnmodules.NNModuleApplier):

    def __init__(self, mnv1headpattern: qe.editors.nnmodules.GenericNNModulePattern):
        super(MNv1HeadApplier, self).__init__(mnv1headpattern)

    def _apply(self, g: fx.GraphModule, ap: qe.editors.nnmodules.NodesMap, id_: str) -> fx.GraphModule:

        name_to_match_node = self.pattern.name_to_match_node(nodes_map=ap)
        node_lin = name_to_match_node['lin']

        name_to_match_module = self.pattern.name_to_match_module(nodes_map=ap, data_gm=g)
        module_eps = name_to_match_module['eps']
        module_lin = name_to_match_module['lin']

        assert module_eps.eps_out.numel() == 1
        assert len(node_lin.all_input_nodes) == 1

        # create the new module
        new_target = id_
        new_module = nn.Linear(in_features=module_lin.in_features, out_features=module_lin.out_features, bias=module_lin.bias is not None)
        # fold the tunnel's output scale into the weights (the bias is unaffected)
        new_weight = module_lin.weight.data.detach().clone() * module_eps.eps_out
        new_module.weight.data = new_weight
        if module_lin.bias is not None:
            new_bias = module_lin.bias.data.detach().clone()
            new_module.bias.data = new_bias

        # add the requantised linear operation to the graph...
        g.add_submodule(new_target, new_module)
        linear_input = next(iter(node_lin.all_input_nodes))
        with g.graph.inserting_after(linear_input):
            new_node = g.graph.call_module(new_target, args=(linear_input,))
        node_lin.replace_all_uses_with(new_node)

        # the scale now lives in the weights, so reset the tunnel's output scale to one
        module_eps.set_eps_out(torch.ones_like(module_eps.eps_out))

        # ...and delete the old operation
        g.delete_submodule(node_lin.target)
        g.graph.erase_node(node_lin)

        return g


class MNv1HeadRewriter(qe.editors.nnmodules.NNModuleRewriter):

    def __init__(self):
        # create pattern
        mnv1headwithcheckers = qe.editors.nnmodules.NNModuleWithCheckers(MNv1Head(), {})
        mnv1headpattern = qe.editors.nnmodules.GenericNNModulePattern(qg.fx.quantlib_symbolic_trace, mnv1headwithcheckers)
        # create matcher and applier
        finder = qe.editors.nnmodules.GenericGraphMatcher(mnv1headpattern)
        applier = MNv1HeadApplier(mnv1headpattern)
        # link pattern, matcher, and applier into the rewriter
        super(MNv1HeadRewriter, self).__init__('MNv1HeadRewriter', mnv1headpattern, finder, applier)
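
The applier relies on the linearity of the head. If the tunnel scales the feature map by a factor `eps`, then `lin(avg(eps * x)) == (eps * W) @ avg(x) + b`, so `eps` can be moved into the linear weights while the bias stays untouched. The toy check below verifies this identity with plain Python lists; the feature map, weights, bias, and scale are made-up values, and the helpers `global_avg_pool` and `linear` are illustrative stand-ins for `nn.AdaptiveAvgPool2d((1, 1))` (followed by flattening) and `nn.Linear`.

```python
from typing import List


def global_avg_pool(fmap: List[List[List[float]]]) -> List[float]:
    # average each channel (a 2D list) over all of its spatial positions
    return [sum(sum(row) for row in ch) / (len(ch) * len(ch[0])) for ch in fmap]


def linear(W: List[List[float]], b: List[float], v: List[float]) -> List[float]:
    # y = W v + b
    return [sum(wij * vj for wij, vj in zip(wi, v)) + bi for wi, bi in zip(W, b)]


eps = 0.25
fmap = [[[4.0, 8.0], [0.0, 12.0]], [[1.0, 3.0], [5.0, 7.0]]]  # 2 channels, 2x2 pixels
W = [[0.5, -1.0], [2.0, 0.25]]
b = [0.1, -0.2]

# original head: scale (tunnel), pool, linear
scaled = [[[eps * v for v in row] for row in ch] for ch in fmap]
y_orig = linear(W, b, global_avg_pool(scaled))

# rewritten head: pool the unscaled map, then apply a linear layer with `eps` folded into W
W_folded = [[eps * wij for wij in wi] for wi in W]
y_folded = linear(W_folded, b, global_avg_pool(fmap))

print(y_orig, y_folded)
```

Note that only the weights absorb `eps`: scaling the bias as well would change the result, which is why the applier copies the bias unchanged.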


We are ready to apply F2T conversion.

When creating an `F2TConverter`, users can pass a `custom_editor` to perform model-specific rewritings.
The converter applies this editor right before removing identity operations from the target `GraphModule`.

As part of its logic, F2T conversion must perform some semantic analysis of the input `GraphModule`.
Since `F2TConverter`s cannot infer all the semantics automatically, the user must provide the shape and the floating-point scale of the network's inputs.
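
The input scale states which real-valued step one unit of the integer input represents, i.e., `x_real ≈ scale * x_int`. As a generic illustration, assuming unsigned 8-bit inputs and a made-up scale (the actual ImageNet preprocessing and the value of `ImageNetStats['quantise']['scale']` are not shown in this section), mapping a normalised pixel onto the integer grid and back looks as follows; `to_uint8` and `from_uint8` are hypothetical helpers.

```python
def to_uint8(x_real: float, scale: float) -> int:
    # nearest integer level, saturated to the UINT8 range [0, 255]
    return max(0, min(255, round(x_real / scale)))


def from_uint8(x_int: int, scale: float) -> float:
    # the real value represented by an integer level
    return scale * x_int


scale = 1.0 / 255.0  # hypothetical: pixels in [0, 1] mapped onto 256 levels
for x_real in (0.0, 0.2, 1.0):
    q = to_uint8(x_real, scale)
    print(x_real, q, from_uint8(q, scale))
```

The round trip loses at most half a quantisation step for values inside the representable range, which is the information the converter needs to propagate scales through the network.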


In [26]:
# F2T conversion and ONNX exporting require structural information about the input
x, _ = next(iter(valid_loader))
x = x[0].unsqueeze(0)

# set the network in evaluation mode to "freeze" the parameters of batch normalisations
mnv1fq_init = mnv1fq_init.to(torch.device('cpu'))  # TODO: the `Tensor`s generated inside the F2T conversion flow are generated for CPU
mnv1fq_init.eval()

# perform the conversion
f2tconverter = qe.f2t.F2T24bitConverter(custom_editor=MNv1HeadRewriter())
mnv1tq = f2tconverter(mnv1fq_init, {'x': {'shape': x.shape, 'scale': torch.Tensor([ImageNetStats['quantise']['scale']])}})


To validate the performance of the integerised network, we need to pass it integerised data.
Thus, we create a `DataLoader` yielding `Tensor` images with integer components.

Although these integers are stored as floating-point numbers, typical integer data ranges (e.g., INT8 or UINT8) are represented exactly when embedded in standard floating-point formats (e.g., FP32 or FP64).
Thus, although networks integerised with QuantLib still use floating-point arithmetic, in practice we observed that their outputs coincide with those of truly integer networks.
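
The exactness claim can be checked directly. FP32 has a 24-bit significand, so every integer with magnitude up to `2**24` is representable without loss, while `2**24 + 1` is not. The sketch below round-trips Python floats through the IEEE-754 single-precision format with the standard `struct` module; the helper `as_float32` is our own.

```python
import struct


def as_float32(value: float) -> float:
    # round-trip a Python float (FP64) through the IEEE-754 single-precision format
    return struct.unpack('f', struct.pack('f', value))[0]


# every INT8 and UINT8 value survives the round trip unchanged
assert all(as_float32(float(v)) == v for v in range(-128, 256))

# integers are exact up to 2**24 ...
print(as_float32(float(2**24)) == 2**24)          # True
# ... but not beyond: 2**24 + 1 is rounded to a neighbouring representable value
print(as_float32(float(2**24 + 1)) == 2**24 + 1)  # False

# conservative bound: if every UINT8 x INT8 product had the maximum magnitude (255 * 128)
# and the same sign, at most this many products could be summed exactly in FP32
max_terms = 2**24 // (255 * 128)
print(max_terms)
```

Sums of integers stay exact in FP32 as long as every partial result stays within `±2**24`; the worst-case bound above is pessimistic, since real accumulations mix signs and magnitudes.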


In [27]:
# create the validation `DataLoader` returning integerised (UINT8) images
int_valid_transform = ImageNetTransform('valid', integerise=True)
int_valid_loader = loader_factory.get_dataloader('valid', int_valid_transform, per_gpu_batch_size, num_workers=n_cpus)


Let's evaluate our true-quantised MobileNetV1.

In [28]:
mnv1tq = mnv1tq.to(device=device)

# set the network in evaluation mode to "freeze" the parameters of batch normalisations
mnv1tq.eval()

# evaluate the network on the validation set
mnv1tq_perf = evaluate_network(int_valid_loader, mnv1tq, device)
print("Accuracy (true-quantised, trained): {:6.2f}%.".format(mnv1tq_perf.accuracy))


Accuracy (true-quantised, trained):  68.32%.


We notice a minor accuracy drop of 0.22 percentage points (from 68.54% to 68.32%) with respect to the fake-quantised network.
This discrepancy is likely due to the propagation of small numerical differences that arise when floating-point batch normalisations are replaced with integerised requantisations.

Note that even after a single fine-tuning epoch, the performance of our integerised MobileNetV1 is still very close to the original floating-point one.


## Part 4: export a backend-specific ONNX model

The QuantLib `backends` package contains the abstractions required to export **ONNX** models and annotate them with backend-specific information.
To demonstrate its usage, we consider the [DORY](https://github.com/pulp-platform/dory) backend.

First, we create the folder to host DORY-specific files.


In [29]:
import shutil

backend_name = 'DORY'
path_export = os.path.join(os.curdir, backend_name)

if os.path.isdir(path_export):  # remove old files
    shutil.rmtree(path_export)
os.mkdir(path_export)


Then, we export an annotated ONNX model.

In [30]:
import quantlib.backends as qb

x = x.to(torch.device('cpu'))
mnv1tq = mnv1tq.to(device=torch.device('cpu'))

exporter = qb.dory.DORYExporter()
exporter.export(network=mnv1tq, input_shape=x.shape, path=path_export)


We can use [Netron](https://netron.app/) to visually inspect the exported model and check that it is an integerised program.

DORY can also verify whether the exported ONNX model can be compiled into an integerised program for PULP platforms.
To perform this consistency check, users must dump the input and the feature maps associated with an example data point.


In [31]:
exporter.dump_features(network=mnv1tq, x=x, path=path_export)

## Conclusion

We have reached the bottom of the part of the deep learning stack covered by QuantLib.
If you are interested in graph optimisations and code generation, you can read the [DORY paper](https://ieeexplore.ieee.org/document/9381618) and check out the DORY repository.

Cheers!
