This example guides first-time QuantLib users through the quantization process. It uses a small pretrained network and walks through post-training per-layer quantization (i.e., representing weight and activation tensors as integers) and deployment (i.e., organizing operations so that they accurately represent the behavior of integer-based hardware).
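To make "representing tensors as integers" concrete, here is a minimal plain-Python sketch of per-layer quantization: a single scale `eps` shared by the whole tensor maps each real value to an unsigned 8-bit code and back. The helper names, the unsigned 8-bit range, and round-to-nearest are illustrative assumptions rather than QuantLib's API; QuantLib's PACT quantizers differ in details such as how they round.

```python
def quantize(x, eps, n_bits=8):
    """Map a real value to an unsigned n-bit integer code, clipping to [0, 2**n_bits - 1]."""
    q_max = 2 ** n_bits - 1
    # Python's round() sends exact .5 ties to the nearest even integer
    return min(max(round(x / eps), 0), q_max)


def dequantize(q, eps):
    """Return the real value that the integer code q stands for."""
    return q * eps


# hypothetical per-layer scale: an 8-bit code covering the real range [0, 6]
eps = 6.0 / 255
for x in [0.0, 1.234, 5.99, 7.5]:
    q = quantize(x, eps)
    print(f"x={x:<6} q={q:<4} x_hat={dequantize(q, eps):.4f}")
```

Values above the representable range (here `7.5`) saturate at the largest code, which is why choosing `eps` well matters.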

We will see how this works through three stages: *FloatingPoint*, *FakeQuantized*, and *TrueQuantized*.
QuantLib uses float32 tensors to represent data at all three stages, including *TrueQuantized*.
This means that QuantLib code can run on GPUs without requiring special hardware support for integer arithmetic.
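As a quick illustration of why float32 tensors can carry *TrueQuantized* data, the sketch below round-trips integers through IEEE-754 single precision using only the standard library (`struct`). It assumes nothing about QuantLib itself: it only shows that float32, with its 24-bit significand, stores every integer up to 2**24 in magnitude exactly, and starts losing integers just past that bound.

```python
import struct


def to_float32(value):
    """Round a Python number to the nearest IEEE-754 float32 and return it as a Python float."""
    return struct.unpack('f', struct.pack('f', value))[0]


# float32 has a 24-bit significand, so every integer n with |n| <= 2**24 survives the round trip
for n in [255, -128, 2 ** 24, 2 ** 24 + 1]:
    print(n, to_float32(n) == n)
```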

Let us start by 1) performing the necessary imports, and 2) setting the target device.

In [1]:
#@title Imports & Set device

#basic
import numpy as np
from pandas import DataFrame
from copy import deepcopy
from tqdm import tqdm
import os

#torch
import torch; print('\nPyTorch version in use:', torch.__version__, '\ncuda avail: ', torch.cuda.is_available())
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

#quantlib!
import quantlib
import quantlib.algorithms as qa
import quantlib.editing.graphs as qg
import quantlib.editing.editing as qe
import quantlib.backends.dory as qd

device = 'cpu' # torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device: %s' % device)


PyTorch version in use: 1.9.0+cu102 
cuda avail:  True
Device: cpu


The first real step is to define the network topology.
This works exactly as in a "standard" PyTorch script, using regular `torch.nn.Module` instances.
QuantLib can transform most layers defined in `torch.nn` into its own representations.
It also performs a process of *Canonicalization* to ensure that common topological constructs,
such as flattening, residual blocks terminating in an addition, and others, are represented consistently.

As QuantLib uses `torch.fx` internally, we want graphs that implement identical functionality to be identical themselves.
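A toy illustration of canonicalization: the sketch below represents a graph as a plain list of `(name, op, inputs)` tuples and rewrites the `size`/`view` idiom used in `ExampleNet.forward` into a single `flatten` node, so both ways of flattening yield the same graph. This is a hypothetical stand-in for the `torch.fx` and QuantLib data structures, and constant arguments such as `0` and `-1` are omitted for brevity.

```python
def canonicalize_flatten(nodes):
    """Rewrite `s = x.size(0); v = x.view(s, -1)` into a single `flatten` node that reads x.

    `nodes` is a toy graph: a list of (name, op, inputs) tuples in execution order.
    """
    result, alias = [], {}
    i = 0
    while i < len(nodes):
        name, op, inputs = nodes[i]
        inputs = tuple(alias.get(n, n) for n in inputs)  # follow earlier rewrites
        nxt = nodes[i + 1] if i + 1 < len(nodes) else None
        if (op == 'size' and nxt is not None and nxt[1] == 'view'
                and tuple(alias.get(n, n) for n in nxt[2]) == (inputs[0], name)):
            result.append(('flatten', 'flatten', (inputs[0],)))
            alias[nxt[0]] = 'flatten'  # consumers of the view now read the flatten node
            i += 2
            continue
        result.append((name, op, inputs))
        i += 1
    return result


# the same computation written with x.view(x.size(0), -1) and with nn.Flatten()
with_view = [('x', 'placeholder', ()), ('relu4', 'relu', ('x',)),
             ('size', 'size', ('relu4',)), ('view', 'view', ('relu4', 'size')),
             ('fc1', 'linear', ('view',))]
with_flatten = [('x', 'placeholder', ()), ('relu4', 'relu', ('x',)),
                ('flatten', 'flatten', ('relu4',)), ('fc1', 'linear', ('flatten',))]
print(canonicalize_flatten(with_view) == canonicalize_flatten(with_flatten))
```

A real canonicalizer matches many such patterns on the `torch.fx` graph; the point here is only that equivalent constructs are rewritten into one normal form.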

In [2]:
#@title Define ExampleNet

class ExampleNet(nn.Module):
    def __init__(self):
        super(ExampleNet, self).__init__()
        
        self.conv1 = nn.Conv2d(1, 4, 3, padding=1, bias=False)
        self.bn1   = nn.BatchNorm2d(4)
        self.relu1 = nn.ReLU() # <== Module, not Function!

        self.conv2 = nn.Conv2d(4, 20, 3, padding=1, stride=2, bias=False)
        self.bn2   = nn.BatchNorm2d(20)
        self.relu2 = nn.ReLU() # <== Module, not Function!

        self.conv3 = nn.Conv2d(20, 40, 3, padding=1, stride=2, bias=False)
        self.bn3   = nn.BatchNorm2d(40)
        self.relu3 = nn.ReLU() # <== Module, not Function!

        self.conv4 = nn.Conv2d(40, 80, 3, padding=1, stride=2, bias=False)
        self.bn4   = nn.BatchNorm2d(80)
        self.relu4 = nn.ReLU() # <== Module, not Function!
        
        self.flatten = nn.Flatten()

        self.fc1 = nn.Linear(80 * 4**2, 500, bias=False)
        #self.fcbn = nn.BatchNorm1d(500)
        self.fcrelu1 = nn.ReLU() # <== Module, not Function!
        self.fc2 = nn.Linear(500, 10, bias=False)

    def forward(self, x):
        x = self.relu1(self.bn1(self.conv1(x)))
        x = self.relu2(self.bn2(self.conv2(x)))
        x = self.relu3(self.bn3(self.conv3(x)))
        x = self.relu4(self.bn4(self.conv4(x)))
        x = x.view(x.size(0), -1)
        # x = self.flatten(x)
        x = self.fcrelu1(self.fc1(x))
        x = self.fc2(x)
        # output = F.log_softmax(x, dim=1) # <== the softmax operation does not need to be quantized, we can keep it as it is
        return x
        
model = ExampleNet().to(device)

Next, we define the testing function (MNIST has no separate validation set, so the test split plays that role).

This is essentially identical to regular PyTorch code, with only one difference: the testing (and validation) function
has a switch to support the production of non-negative integer data.

This is important to test the last stage of quantization, i.e., *TrueQuantized*.

The same conversion could also be performed inside the data loaders;
in this example, however, we use the standard `torchvision` data loaders for MNIST and apply the conversion in the testing function.
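The `integer=True` switch multiplies by 255 because `transforms.ToTensor()` turns 8-bit pixels into floats in [0, 1] by dividing by 255. Floating-point division followed by multiplication is not guaranteed to return the original integer bit-exactly, so a variant that also rounds may be a safer way to obtain integer inputs. This is a plain-Python sketch on a list; the helper name `to_integer_pixels` is hypothetical, and in the notebook the equivalent would be something like `torch.round(data * 255)`.

```python
def to_integer_pixels(batch):
    """Undo ToTensor's division by 255 and round away floating-point error, giving integers in [0, 255]."""
    return [float(round(v * 255)) for v in batch]


# ToTensor stores an 8-bit pixel p as p / 255 (float32 in PyTorch; Python floats here)
pixels = list(range(256))
normalized = [p / 255 for p in pixels]
restored = to_integer_pixels(normalized)
print(restored[:4], restored[-1])
```

The values stay floats on purpose, matching QuantLib's convention of carrying integer data in floating-point tensors.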

In [3]:
#@title Define Metrics and validation function

# convenience class to keep track of averages
class Metric(object):
    def __init__(self, name):
        self.name = name
        self.sum  = 0
        self.n    = 0
    def update(self, value):
        self.sum += value
        self.n += 1
    @property
    def avg(self):
        return self.sum / self.n

def validate(model, device, dataloader, verbose=True, integer=False):
    model.eval()
    loss    = 0
    correct = 0
    acc     = Metric('test_acc')
    with tqdm(
        total=len(dataloader), desc='Validation', disable=not verbose,
        ) as t:
        with torch.no_grad():
            for data, target in dataloader:
                if integer:      # support for production of
                    data *= 255  # non-negative integer data
                data, target = data.to(device), target.to(device)
                output = model(data)
                loss += F.cross_entropy(output, target, reduction='sum').item() # sum up batch loss (the model returns raw logits)
                pred = output.argmax(dim=1) # get index of the largest logit
                correct += pred.eq(target).sum().item()
                acc.update(pred.eq(target).float().mean().item())
                t.set_postfix({'acc': acc.avg})
                t.update(1)
    loss /= len(dataloader.dataset)
    return acc.avg


# calibration set
Mcalib = 1024 # calibration set size
train_set = datasets.MNIST('./data', train=True , download=True, transform=transforms.ToTensor())
calib_set = torch.utils.data.Subset(train_set, indices=np.random.permutation(len(train_set))[:Mcalib])
del train_set # we load a pretrained model, we won't train it in this script. Just needed for calibration!

# validation set
valid_set = datasets.MNIST('./data', train=False, download=True, transform=transforms.ToTensor())

# set up the dataloaders
kwargs = {'num_workers': 1, 'pin_memory': True} if torch.cuda.is_available() else {}
calib_loader = torch.utils.data.DataLoader(calib_set, batch_size=128, shuffle=False, drop_last=True, **kwargs)
valid_loader = torch.utils.data.DataLoader(valid_set, batch_size=128, shuffle=False, drop_last=True, **kwargs)

os.system('rm -rf examplenet.pt')
os.system('wget https://github.com/MarcelloZanghieri2/NeMO_tutorial/blob/main/smallernet_4.pt?raw=true')
os.system('mv smallernet_4.pt?raw=true examplenet.pt')

model = ExampleNet().to(device)
state_dict = torch.load('examplenet.pt', map_location='cpu')
model.load_state_dict(state_dict, strict=True)

acc = validate(model, device, valid_loader)
print("\n\nFullPrecision accuracy: %.3f" % (acc))

Validation: 100%|██████████| 78/78 [00:03<00:00, 19.74it/s, acc=0.99] 



FullPrecision accuracy: 0.990





The first step toward quantization is tracing the graph of the model with the QuantLib tracer: the following cell does this, then prints a human-readable summary of the traced graph.
To simplify the operation, we use the tracer provided in `quantlib.editing.graphs.fx` (imported here as `qg.fx`), which wraps the `torch.fx` tracer.
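For readers new to `torch.fx`: symbolic tracing works by calling the model's `forward` on proxy objects that record each operation instead of computing it. The toy tracer below (hypothetical classes supporting only `+` and `*`) mimics that mechanism in plain Python. It is not QuantLib's or `torch.fx`'s implementation, just the idea behind the tabular listing produced by the next cell.

```python
class Proxy:
    """Stand-in value that records every operation applied to it into a shared graph."""

    def __init__(self, name, graph):
        self.name, self.graph = name, graph

    def _record(self, op, *args):
        node = f"{op}_{len(self.graph)}"
        arg_names = tuple(a.name if isinstance(a, Proxy) else a for a in args)
        self.graph.append((node, op, arg_names))
        return Proxy(node, self.graph)

    def __add__(self, other):
        return self._record('add', self, other)

    def __mul__(self, other):
        return self._record('mul', self, other)


def symbolic_trace(fn):
    """Call fn on a Proxy instead of real data and return the recorded (name, op, args) list."""
    graph = [('x', 'placeholder', ())]
    out = fn(Proxy('x', graph))
    graph.append(('output', 'output', (out.name,)))
    return graph


def f(x):
    return x * 2 + x


for node in symbolic_trace(f):
    print(node)
```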

In [4]:
#@title Trace floating-point model graph and print it in human-readable format

# Symbolic trace of the graph
model_fp = qg.fx.quantlib_symbolic_trace(root=model)

# Print the graph in tabular format
model_fp.graph.print_tabular()

opcode       name     target    args               kwargs
-----------  -------  --------  -----------------  --------
placeholder  x        x         ()                 {}
call_module  conv1    conv1     (x,)               {}
call_module  bn1      bn1       (conv1,)           {}
call_module  relu1    relu1     (bn1,)             {}
call_module  conv2    conv2     (relu1,)           {}
call_module  bn2      bn2       (conv2,)           {}
call_module  relu2    relu2     (bn2,)             {}
call_module  conv3    conv3     (relu2,)           {}
call_module  bn3      bn3       (conv3,)           {}
call_module  relu3    relu3     (bn3,)             {}
call_module  conv4    conv4     (relu3,)           {}
call_module  bn4      bn4       (conv4,)           {}
call_module  relu4    relu4     (bn4,)             {}
call_method  size     size      (relu4, 0)         {}
call_method  view     view      (relu4, size, -1)  {}
call_module  fc1      fc1       (view,)            {}
call_module  fcrel

The first "real" step for 8-bit quantization is the Float2Fake conversion, performed by the editor created with `quantlib.editing.editing.float2fake.F2F8bitPACTRoundingConverter()` (`qe.float2fake.F2F8bitPACTRoundingConverter()` in the cell below).
This convenience editor bundles several useful transformations, such as:
 - canonicalization, i.e., transforming the network into the "canonical" format discussed previously;
 - float2fake quantization proper, which replaces standard `nn.Conv2d`, `nn.Linear`, `nn.ReLU`, etc., with quantized modules (in this case, `PACTConv2d`, `PACTLinear`, and `PACTReLU`, respectively);
 - folding of bias parameters into batch-normalization layers;
 - quantization of non-activated tensors before Add layers;
 - setup of rounding for weights and activations, to avoid systematic bias in quantization; the latter is performed by folding the rounding bias into batch-normalization layers whenever possible. This operation must be completed after calibration.

`F2F8bitPACTRoundingConverter()` is a specialized version of `F2F8bitPACTConverter()`; special needs may require modifying it by subclassing or adaptation, but for our purposes here it can be used as is.
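The bias folding listed above rests on a simple identity: in inference mode, `BN(z + b) = gamma * (z + b - mean) / sqrt(var + eps) + beta`, which equals BN applied to `z` with running mean `mean - b`. The per-channel sketch below checks this numerically; `fold_bias_into_bn` is a hypothetical helper and the parameter values are made up. Note that all convolutions in `ExampleNet` are declared with `bias=False`, so in this particular network there is nothing to fold.

```python
import math


def batchnorm(y, mean, var, gamma, beta, bn_eps=1e-5):
    """Inference-mode batch normalization of a single channel value."""
    return gamma * (y - mean) / math.sqrt(var + bn_eps) + beta


def fold_bias_into_bn(bias, mean):
    """Absorb a preceding layer's bias b into the BN running mean: BN(z + b) == BN'(z) with mean' = mean - b."""
    return mean - bias


# hypothetical per-channel values
z, bias = 0.8, 0.3                  # bias-free convolution output and the convolution bias
mean, var, gamma, beta = 0.5, 2.0, 1.5, -0.1

before = batchnorm(z + bias, mean, var, gamma, beta)                   # conv with bias, then BN
after = batchnorm(z, fold_bias_into_bn(bias, mean), var, gamma, beta)  # bias-free conv, folded BN
print(before, after, math.isclose(before, after))
```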

In [5]:
#float 2 fake
f2fconverter = qe.float2fake.F2F8bitPACTRoundingConverter()
model_fq = f2fconverter(model_fp)

# set validation state
model_fq.eval()


ExampleNet(
  (conv1): PACTConv2d(1, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu1): PACTReLU()
  (conv2): PACTConv2d(4, 20, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (bn2): BatchNorm2d(20, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu2): PACTReLU()
  (conv3): PACTConv2d(20, 40, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (bn3): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu3): PACTReLU()
  (conv4): PACTConv2d(40, 80, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (bn4): BatchNorm2d(80, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu4): PACTReLU()
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (fc1): PACTLinear(in_features=1280, out_features=500, bias=False)
  (fcrelu1): PACTReLU()
  (fc2): PACTLinear(in_features=500, o

We can also print the `torch.fx` graph to see how the model has subtly changed with respect to the *FloatingPoint* version: notice in particular that the `size`/`view` pair has been replaced by a single `flatten` node, as a result of canonicalization.

In [6]:
# Print the graph in tabular format
model_fq.graph.print_tabular()

opcode       name     target    args        kwargs
-----------  -------  --------  ----------  --------
placeholder  x        x         ()          {}
call_module  conv1    conv1     (x,)        {}
call_module  bn1      bn1       (conv1,)    {}
call_module  relu1    relu1     (bn1,)      {}
call_module  conv2    conv2     (relu1,)    {}
call_module  bn2      bn2       (conv2,)    {}
call_module  relu2    relu2     (bn2,)      {}
call_module  conv3    conv3     (relu2,)    {}
call_module  bn3      bn3       (conv3,)    {}
call_module  relu3    relu3     (bn3,)      {}
call_module  conv4    conv4     (relu3,)    {}
call_module  bn4      bn4       (conv4,)    {}
call_module  relu4    relu4     (bn4,)      {}
call_module  flatten  flatten   (relu4,)    {}
call_module  fc1      fc1       (flatten,)  {}
call_module  fcrelu1  fcrelu1   (fc1,)      {}
call_module  fc2      fc2       (fcrelu1,)  {}
output       output   output    (fc2,)      {}


In most cases, the model at this stage is not yet fully functional, because the quantization parameters (in particular the scale `eps`) are not aligned with the actual activations flowing through the network. For this MNIST experiment, it might even work! In general, however, we need to
 1. calibrate the network with real activation data
 2. complete the rounding procedure
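A minimal sketch of what "calibrating with real activation data" can mean, assuming a simple max-based statistic: track the largest ReLU output seen over the calibration batches and set `eps = alpha / 255` for 8 bits. QuantLib's actual calibration statistics may differ, and the helper names and activation values are hypothetical. The comparison shows why an uncalibrated scale (here an assumed default of `1/255`, i.e. clipping at 1.0) can break the network: larger activations saturate.

```python
import math


def calibrate_eps(batches, n_bits=8):
    """Max-based calibration: alpha = largest activation seen, eps = alpha / (2**n_bits - 1)."""
    alpha = 0.0
    for batch in batches:
        alpha = max(alpha, max(batch))  # running statistic over the calibration batches
    return alpha / (2 ** n_bits - 1)


def quantize_act(x, eps, n_bits=8):
    """Unsigned, clipped, floor-based activation quantizer (rounding is a separate step)."""
    return min(max(math.floor(x / eps), 0), 2 ** n_bits - 1)


batches = [[0.0, 0.7, 1.9], [0.2, 2.55], [1.1, 0.4]]  # hypothetical ReLU outputs
values = [v for batch in batches for v in batch]


def max_error(eps):
    """Largest absolute gap between an activation and its quantized value."""
    return max(abs(quantize_act(v, eps) * eps - v) for v in values)


eps_default = 1.0 / 255             # assumed uncalibrated scale: clips everything above 1.0
eps_calib = calibrate_eps(batches)  # 2.55 / 255
print(f"max error before calibration: {max_error(eps_default):.3f}")
print(f"max error after calibration:  {max_error(eps_calib):.3f}")
```

Plain max calibration is sensitive to outliers; percentile-based statistics are a common alternative when a few extreme activations would otherwise waste most of the 8-bit range.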

In [7]:
# collect statistics about the floating-point `Tensor`s passing through the quantisers, so that we can better fit the quantisers' hyper-parameters
with qe.float2fake.calibration(model_fq):
    acc = validate(model_fq, device, calib_loader)

# adds rounding to all PACT operators
rounder = qe.float2fake.F2F8bitPACTRounder()
model_fq_rounded = rounder(model_fq)

model_fq_rounded.to(device)
model_fq_rounded.eval()

acc = validate(model_fq_rounded, device, valid_loader)
print("\nFakeQuantized with calibration+rounding: accuracy: %.3f" % (acc))

Validation: 100%|██████████| 8/8 [00:00<00:00, 24.07it/s, acc=0.997]
Validation: 100%|██████████| 78/78 [00:02<00:00, 32.95it/s, acc=0.99] 


FakeQuantized with calibration+rounding: accuracy: 0.990
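Why rounding matters: if the base quantizer floors (`floor(x / eps)`), every value is pushed down by about `eps / 2` on average, which is a systematic bias. Adding `eps / 2` before flooring turns it into round-to-nearest, and because that offset is a constant it can be absorbed into a preceding bias or batch-normalization shift, which is the folding described for `F2F8bitPACTRoundingConverter()`. The plain-Python sketch below measures both mean errors on random data; the floor-based starting point is an assumption about PACT-style quantizers, not a statement about QuantLib's exact code.

```python
import math
import random


def quantize_floor(x, eps):
    """Floor-based code: always rounds toward minus infinity."""
    return math.floor(x / eps)


def quantize_round(x, eps):
    """Round-to-nearest written as floor with a constant +eps/2 offset (the part that can be folded away)."""
    return math.floor((x + eps / 2) / eps)


random.seed(0)
eps = 0.05
xs = [random.uniform(0.0, 5.0) for _ in range(10_000)]
err_floor = sum(quantize_floor(x, eps) * eps - x for x in xs) / len(xs)
err_round = sum(quantize_round(x, eps) * eps - x for x in xs) / len(xs)
print(f"mean error with floor: {err_floor:+.4f} (close to -eps/2 = {-eps / 2:+.4f})")
print(f"mean error with round: {err_round:+.4f} (close to 0)")
```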





The *FakeQuantized* network includes some quantization information (mainly for linear and activation layers), but it is not fully equivalent to a wholly quantized network.
The Fake2True transformation completes the quantization of the network so that all layers consume and produce integer Tensors. It does so by means of `eps` propagation, i.e., by propagating scales throughout the network.
To kickstart this process, QuantLib needs:
 1. an example input (its values may be random: only the shape is used)
 2. the scale of the input, i.e., the real value represented by an element equal to `1` in the input Tensor

The default converter is `quantlib.editing.editing.f2t.F2T24bitConverter()` (`qe.f2t.F2T24bitConverter()` in the cell below). The bit width in its name refers to the number of bits allocated to intermediate accumulation values.
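The core of `eps` propagation for a linear (or convolutional) layer fits in a few lines: if inputs are `x = eps_x * x_int` and weights are `w = eps_w * w_int`, the dot product equals `eps_x * eps_w` times an integer-only accumulation, so the accumulator's scale is `eps_x * eps_w`. In this sketch the input scale `1/255` matches the `0.0039216` passed to the converter below; the weight scale and integer values are hypothetical, and `fits_signed` assumes the 24-bit accumulator is a signed two's-complement integer.

```python
import math


def true_quant_linear(x_int, w_int):
    """Integer-only dot product: the accumulator a TrueQuantized layer computes."""
    return sum(xi * wi for xi, wi in zip(x_int, w_int))


def fake_quant_linear(x_int, w_int, eps_x, eps_w):
    """FakeQuantized view: dequantize inputs and weights, then multiply and add in floating point."""
    return sum((xi * eps_x) * (wi * eps_w) for xi, wi in zip(x_int, w_int))


def fits_signed(value, n_bits):
    """True if value fits in an n_bits two's-complement integer."""
    return -(2 ** (n_bits - 1)) <= value < 2 ** (n_bits - 1)


eps_x = 1 / 255              # input scale (about 0.0039216)
eps_w = 0.02                 # hypothetical weight scale
x_int = [0, 12, 255, 128]    # unsigned 8-bit input codes
w_int = [-3, 7, 1, -128]     # signed 8-bit weight codes

acc = true_quant_linear(x_int, w_int)
eps_acc = eps_x * eps_w      # propagated scale of the accumulator
print(acc, math.isclose(acc * eps_acc, fake_quant_linear(x_int, w_int, eps_x, eps_w)), fits_signed(acc, 24))
```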

In [8]:
# get an example input (a batch of size 1)
x, _ = next(iter(valid_loader))
x = x[0].unsqueeze(0)

# convert to TrueQuantized with default 24-bit converter
# input scale 0.0039216 ~= 1/255, since ToTensor maps integer pixel values in [0, 255] to [0, 1]
f2tconverter = qe.f2t.F2T24bitConverter()
model_tq = f2tconverter(model_fq_rounded, {'x': {'shape': x.shape, 'scale': torch.tensor((0.0039216,))}})



If we print the *TrueQuantized* graph in readable format, we can see that it differs from before: the layers have been replaced by integerized counterparts whose names reflect their nature (e.g., `EpsQConv2dEpsIntegeriser` and `EpsBN2dQReLUEpsRequantiser`), with batch-normalization and ReLU appearing merged into single requantiser nodes. Moreover, just before the output you can spot an `EpsTunnel` layer that performs an important function: it takes the integer output of the network and rescales it into a floating-point value, to ensure output consistency.
The input of the *TrueQuantized* model, on the other hand, must already be integerized when fed to the network; hence we use `integer=True` when running validation.
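The fused `EpsBN2dQReLUEpsRequantiser` nodes map a wide integer accumulator back to an 8-bit activation. A common integer-only way to do this, sketched below under our own assumptions, is to fold the BN slope and offset together with the ratio of scales into a fixed-point multiplier `mul` and offset `add`, then shift right by `d` bits and clip (the clip doubles as ReLU). The channel parameters, `d = 16`, and the helper names are hypothetical, and QuantLib's requantisers may compute their constants differently; the check compares against floating-point BN + ReLU and only expects agreement within one output code.

```python
import math


def make_requantizer(eps_in, eps_out, gamma, beta, mean, var, bn_eps=1e-5, d=16, n_bits=8):
    """Build an integer-only BN + ReLU + rescale for one channel: clip((acc * mul + add) >> d, 0, 2**n_bits - 1)."""
    slope = gamma / math.sqrt(var + bn_eps)          # BN slope in real units
    offset = beta - slope * mean                     # BN offset in real units
    mul = round(slope * eps_in / eps_out * 2 ** d)   # fixed-point multiplier
    add = round(offset / eps_out * 2 ** d)           # fixed-point offset
    q_max = 2 ** n_bits - 1

    def requantize(acc):
        # Python's >> on ints is an arithmetic shift, i.e. floor division by 2**d
        return min(max((acc * mul + add) >> d, 0), q_max)

    return requantize


def reference(acc, eps_in, eps_out, gamma, beta, mean, var, bn_eps=1e-5, n_bits=8):
    """Floating-point BN + ReLU on the dequantized accumulator, then floor-quantization to eps_out."""
    y = gamma * (acc * eps_in - mean) / math.sqrt(var + bn_eps) + beta
    return min(max(math.floor(y / eps_out), 0), 2 ** n_bits - 1)


# hypothetical channel parameters
params = dict(eps_in=1e-4, eps_out=0.02, gamma=1.2, beta=0.1, mean=0.3, var=0.5)
rq = make_requantizer(**params)
worst = max(abs(rq(acc) - reference(acc, **params)) for acc in range(-20000, 20001, 7))
print("largest disagreement (in output codes):", worst)
```

A larger shift `d` makes `mul` and `add` more precise at the cost of a wider intermediate product.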

In [9]:
# Print the graph in tabular format
model_tq.graph.print_tabular()

# Test the network
acc = validate(model_tq, device, valid_loader, integer=True)
print("\nTrueQuantized: accuracy: %.3f" % (acc))


opcode       name                                                   target                                             args                                                      kwargs
-----------  -----------------------------------------------------  -------------------------------------------------  --------------------------------------------------------  --------
placeholder  x                                                      x                                                  ()                                                        {}
call_module  ql_eps_qconv2d_eps_integeriser_139755626488464__4_     QL_EpsQConv2dEpsIntegeriser_139755626488464__4_    (x,)                                                      {}
call_module  ql_eps_bn2d_qre_lueps_requantiser_139755626052688__4_  QL_EpsBN2dQReLUEpsRequantiser_139755626052688__4_  (ql_eps_qconv2d_eps_integeriser_139755626488464__4_,)     {}
call_module  ql_eps_qconv2d_eps_integeriser_139755626488464__3_     QL_EpsQConv2dEpsIntege

Validation: 100%|██████████| 78/78 [00:01<00:00, 46.22it/s, acc=0.99] 


TrueQuantized: accuracy: 0.990





If we want to export the network, e.g., to DORY, we need to remove the final `EpsTunnel`. QuantLib contains a suitable editor, which also prints the integer-to-float scaling factors that the removed `EpsTunnel` applied.

In [10]:
epsremover = qe.f2t.FinalEpsTunnelRemover()
model_tq_removed = epsremover(model_tq)

[FinalEpsTunnelRemover] output: removing EpsTunnel with scaling factor tensor([[0.0003171186253894, 0.0002228131197626, 0.0001979214721359,
         0.0002190211525885, 0.0001621314731892, 0.0001564017584315,
         0.0002467445738148, 0.0001997630897677, 0.0001304127363255,
         0.0002063974534394]])
[FinalEpsTunnelRemover] output: outputs will need to be scaled *externally* to maintain program semantics.
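Since the removed `EpsTunnel` carried a *different* scaling factor for each of the 10 outputs, skipping the external rescaling can change more than the magnitudes: even the predicted class may differ. The sketch copies the factors printed above and uses hypothetical integer outputs to show the effect; `rescale` and `argmax` are illustrative helpers, not QuantLib functions.

```python
# per-class scaling factors printed by FinalEpsTunnelRemover (copied from the log)
scales = [0.0003171186253894, 0.0002228131197626, 0.0001979214721359,
          0.0002190211525885, 0.0001621314731892, 0.0001564017584315,
          0.0002467445738148, 0.0001997630897677, 0.0001304127363255,
          0.0002063974534394]


def rescale(int_outputs, scales):
    """Restore the real-valued outputs that the removed EpsTunnel used to produce."""
    return [q * s for q, s in zip(int_outputs, scales)]


def argmax(values):
    """Index of the largest element."""
    return max(range(len(values)), key=values.__getitem__)


# hypothetical integer outputs of the exported network for one image
int_out = [100, 20, 35, 60, 80, 150, 10, 40, 90, 70]
print(argmax(int_out), argmax(rescale(int_out, scales)))
```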


The final export to DORY, for further testing, can use the APIs embedded in QuantLib:

In [11]:
doryexporter = qd.DORYExporter()
doryexporter.export(model_tq_removed, x.shape, ".")
doryexporter.dump_features(model_tq_removed, x, ".")