# Train a Quantized Lenet5 on MNIST with Brevitas

In [1]:
pip install tqdm --upgrade

Defaulting to user installation because normal site-packages is not writeable
Requirement already up-to-date: tqdm in /tmp/home_dir/.local/lib/python3.8/site-packages (4.66.2)
Note: you may need to restart the kernel to use updated packages.


In [2]:
pip install tensorboard

Defaulting to user installation because normal site-packages is not writeable
Note: you may need to restart the kernel to use updated packages.


In [3]:
from torch.utils.data import Dataset

In [4]:
import onnx
import torch

# Load Dataset <a id="train_qnn"></a>


In [5]:
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision import datasets

train_val_dataset = datasets.MNIST(root="./datasets/", train=True, download=False, transform=transforms.ToTensor())
test_dataset = datasets.MNIST(root="./datasets", train=False, download=False, transform=transforms.ToTensor())

# Calculate mean and std
imgs = torch.stack([img for img, _ in train_val_dataset], dim=0)

mean = imgs.view(1, -1).mean(dim=1)    # or imgs.mean()
std = imgs.view(1, -1).std(dim=1)     # or imgs.std()

mnist_transforms = transforms.Compose([transforms.ToTensor(),
                                       transforms.Normalize(mean=mean, std=std)])

train_val_dataset = datasets.MNIST(root="./datasets/", train=True, download=False, transform=mnist_transforms)
test_dataset = datasets.MNIST(root="./datasets/", train=False, download=False, transform=mnist_transforms)

train_size = int(0.9 * len(train_val_dataset))
val_size = len(train_val_dataset) - train_size

train_dataset, val_dataset = torch.utils.data.random_split(dataset=train_val_dataset, lengths=[train_size, val_size])

BATCH_SIZE = 32

train_dataloader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=True)


# Define a PyTorch Device <a id='define_pytorch_device'></a> 

GPUs can significantly speed-up training of deep neural networks. We check for availability of a GPU and if so define it as target device.


In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Target device: " + str(device))

Target device: cpu


  return torch._C._cuda_getDeviceCount() > 0


# Define the Quantized MLP Model <a id='define_quantized_mlp'></a>

We'll now define an MLP model that will be trained to perform inference with quantized weights and activations.
For this, we'll use the quantization-aware training (QAT) capabilities offered by [Brevitas](https://github.com/Xilinx/brevitas).

Our MLP will have four fully-connected (FC) layers in total: three hidden layers with 64 neurons, and a final output layer with a single output, all using 2-bit weights. We'll use 2-bit quantized ReLU activation functions, and apply batch normalization between each FC layer and its activation.

In case you'd like to experiment with different quantization settings or topology parameters, we'll define all these topology settings as variables.

In [7]:
from torch import nn

# Setting seeds for reproducibility
torch.manual_seed(0)

class lenet5(nn.Module):
    def __init__(self):
        super().__init__()
        self.feature = nn.Sequential(
            #1
            nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=2),   # 28*28->32*32-->28*28
            nn.Tanh(),
            nn.AvgPool2d(kernel_size=2, stride=2),  # 14*14
            
            #2
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1),  # 10*10
            nn.Tanh(),
            nn.AvgPool2d(kernel_size=2, stride=2),  # 5*5
            
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features=16*5*5, out_features=120),
            nn.Tanh(),
            nn.Linear(in_features=120, out_features=84),
            nn.Tanh(),
            nn.Linear(in_features=84, out_features=10),
        )
        
    def forward(self, x):
        return self.classifier(self.feature(x))

model = lenet5()
model.to(device)

lenet5(
  (feature): Sequential(
    (0): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): Tanh()
    (2): AvgPool2d(kernel_size=2, stride=2, padding=0)
    (3): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
    (4): Tanh()
    (5): AvgPool2d(kernel_size=2, stride=2, padding=0)
  )
  (classifier): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=400, out_features=120, bias=True)
    (2): Tanh()
    (3): Linear(in_features=120, out_features=84, bias=True)
    (4): Tanh()
    (5): Linear(in_features=84, out_features=10, bias=True)
  )
)

In [8]:
from torchinfo import summary

summary(model=model, input_size=(1, 1, 28, 28), col_width=20,
                  col_names=['input_size', 'output_size', 'num_params', 'trainable'], row_settings=['var_names'], verbose=0)

Layer (type (var_name))                  Input Shape          Output Shape         Param #              Trainable
lenet5 (lenet5)                          [1, 1, 28, 28]       [1, 10]              --                   True
├─Sequential (feature)                   [1, 1, 28, 28]       [1, 16, 5, 5]        --                   True
│    └─Conv2d (0)                        [1, 1, 28, 28]       [1, 6, 28, 28]       156                  True
│    └─Tanh (1)                          [1, 6, 28, 28]       [1, 6, 28, 28]       --                   --
│    └─AvgPool2d (2)                     [1, 6, 28, 28]       [1, 6, 14, 14]       --                   --
│    └─Conv2d (3)                        [1, 6, 14, 14]       [1, 16, 10, 10]      2,416                True
│    └─Tanh (4)                          [1, 16, 10, 10]      [1, 16, 10, 10]      --                   --
│    └─AvgPool2d (5)                     [1, 16, 10, 10]      [1, 16, 5, 5]        --                   --
├─Sequential (classifi

# Train the QNN <a id="train_qnn"></a>

We provide two options for training below: you can opt for training the model from scratch (slower) or use a pre-trained model (faster). The first option will give more insight into how the training process works, while the second option will likely give better accuracy.

In [9]:
from sklearn.metrics import accuracy_score
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001)

In [10]:
from tqdm.notebook import tqdm
from torch.utils.tensorboard import SummaryWriter

from datetime import datetime
import os

# Experiment tracking
timestamp = datetime.now().strftime("%Y-%m-%d")
experiment_name = "MNIST"
model_name = "LeNet5V1"
log_dir = os.path.join("runs", timestamp, experiment_name, model_name)
writer = SummaryWriter(log_dir)

# device-agnostic setup
#device = 'cuda' if torch.cuda.is_available() else 'cpu'
#accuracy = accuracy.to(device)
#model_lenet5v1 = model_lenet5v1.to(device)

EPOCHS = 12

for epoch in tqdm(range(EPOCHS)):
    # Training loop
    train_loss, train_acc = 0.0, 0.0

    for X, y in train_dataloader:
        X, y = X.to(device), y.to(device)
        
        model.train()
        
        y_pred = model(X) 
        y_predt = []
        for weight in y_pred:
            y_predt.append(torch.argmax(weight))

        #print(y_predt,y)
        loss = loss_fn(y_pred, y)
        train_loss += loss.item()
        
        acc = accuracy_score(y_predt, y)
        train_acc += acc
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    train_loss /= len(train_dataloader)
    train_acc /= len(train_dataloader)
        
    # Validation loop
    val_loss, val_acc = 0.0, 0.0
    model.eval()
    #with torch.inference_mode():
    for X, y in val_dataloader:
        X, y = X.to(device), y.to(device)
            
        y_pred = model(X)

        y_predt = []
        for weight in y_pred:
            y_predt.append(torch.argmax(weight))
            
        loss = loss_fn(y_pred, y)
        val_loss += loss.item()
            
        acc = accuracy_score(y_predt, y)
        val_acc += acc
            
    val_loss /= len(val_dataloader)
    val_acc /= len(val_dataloader)
        
    writer.add_scalars(main_tag="Loss", tag_scalar_dict={"train/loss": train_loss, "val/loss": val_loss}, global_step=epoch)
    writer.add_scalars(main_tag="Accuracy", tag_scalar_dict={"train/acc": train_acc, "val/acc": val_acc}, global_step=epoch)
    
    print(f"Epoch: {epoch}| Train loss: {train_loss: .5f}| Train acc: {train_acc: .5f}| Val loss: {val_loss: .5f}| Val acc: {val_acc: .5f}")

  0%|          | 0/12 [00:00<?, ?it/s]

Epoch: 0| Train loss:  0.21497| Train acc:  0.93767| Val loss:  0.07848| Val acc:  0.97507
Epoch: 1| Train loss:  0.07235| Train acc:  0.97740| Val loss:  0.05275| Val acc:  0.98421
Epoch: 2| Train loss:  0.05167| Train acc:  0.98376| Val loss:  0.04760| Val acc:  0.98388
Epoch: 3| Train loss:  0.04124| Train acc:  0.98676| Val loss:  0.04755| Val acc:  0.98504
Epoch: 4| Train loss:  0.03346| Train acc:  0.98919| Val loss:  0.04873| Val acc:  0.98554
Epoch: 5| Train loss:  0.02875| Train acc:  0.99115| Val loss:  0.05294| Val acc:  0.98438
Epoch: 6| Train loss:  0.02547| Train acc:  0.99167| Val loss:  0.04079| Val acc:  0.98853
Epoch: 7| Train loss:  0.02272| Train acc:  0.99232| Val loss:  0.04954| Val acc:  0.98554
Epoch: 8| Train loss:  0.01831| Train acc:  0.99361| Val loss:  0.04276| Val acc:  0.98853
Epoch: 9| Train loss:  0.01691| Train acc:  0.99458| Val loss:  0.04437| Val acc:  0.98737
Epoch: 10| Train loss:  0.01536| Train acc:  0.99469| Val loss:  0.04605| Val acc:  0.9877

# Test the QNN <a id="train_qnn"></a>

In [11]:
test_loss, test_acc = 0, 0

model.eval()
for X, y in test_dataloader:
    X, y = X.to(device), y.to(device)
    y_pred = model(X)

    y_predt = []
    for weight in y_pred:
        y_predt.append(torch.argmax(weight))
        
    test_loss += loss_fn(y_pred, y)
    test_acc += accuracy_score(y_predt, y)
        
test_loss /= len(test_dataloader)
test_acc /= len(test_dataloader)

print(f"Test loss: {test_loss: .5f}| Test acc: {test_acc: .5f}")

Test loss:  0.04169| Test acc:  0.98722


In [12]:
# Save the Brevitas model to disk
torch.save(model.state_dict(), "lenet5.pth")

# Export to FINN-ONNX <a id="export_finn_onnx" ></a>


[ONNX](https://onnx.ai/) is an open format built to represent machine learning models, and the FINN compiler expects an ONNX model as input. We'll now export our network into ONNX to be imported and used in FINN for the next notebooks. Note that the particular ONNX representation used for FINN differs from standard ONNX, you can read more about this [here](https://finn.readthedocs.io/en/latest/internals.html#intermediate-representation-finn-onnx).

You can see below how we export a trained network in Brevitas into a FINN-compatible ONNX representation. Note how we create a `QuantTensor` instance with dummy data to tell Brevitas how our inputs look like, which will be used to set the input quantization annotation on the exported model.

In [13]:
import brevitas.onnx as bo
import numpy as np
from brevitas.quant_tensor import QuantTensor

ready_model_filename = "lenet5.onnx"
input = torch.ones([1, 1, 28, 28], dtype = torch.float32)

#Move to CPU before export
model.cpu()

# Export to ONNX
torch.onnx.export(
    model, input, ready_model_filename
)

print("Model saved to %s" % ready_model_filename)

Model saved to lenet5.onnx


# Compare FINN & Brevitas execution <a id="compare_brevitas"></a>

Load  brevitas model.

In [14]:
brevitas_model = lenet5()
trained_state_dict = torch.load("lenet5.pth")
brevitas_model.load_state_dict(trained_state_dict, strict=False)

<All keys matched successfully>

In [15]:
def inference_with_brevitas(current_inp):
    output = torch.argmax(brevitas_model(current_inp))
    brevitas_output = output.detach().numpy().tolist()
    #print(brevitas_output[0])
    return  brevitas_output

Now that we have the model in .onnx format, we can work with it using FINN. To import it into FINN, we'll use the [`ModelWrapper`](https://finn.readthedocs.io/en/latest/source_code/finn.core.html#qonnx.core.modelwrapper.ModelWrapper). It is a wrapper around the ONNX model which provides several helper functions to make it easier to work with the model.

In [16]:
from qonnx.core.modelwrapper import ModelWrapper

ready_model_filename = "lenet5.onnx"
model_for_sim = ModelWrapper(ready_model_filename)


In [17]:
from qonnx.core.datatype import DataType

finnonnx_in_tensor_name = model_for_sim.graph.input[0].name
finnonnx_out_tensor_name = model_for_sim.graph.output[0].name
print("Input tensor name: %s" % finnonnx_in_tensor_name)
print("Output tensor name: %s" % finnonnx_out_tensor_name)
finnonnx_model_in_shape = model_for_sim.get_tensor_shape(finnonnx_in_tensor_name)
finnonnx_model_out_shape = model_for_sim.get_tensor_shape(finnonnx_out_tensor_name)
print("Input tensor shape: %s" % str(finnonnx_model_in_shape))
print("Output tensor shape: %s" % str(finnonnx_model_out_shape))
finnonnx_model_in_dt = model_for_sim.get_tensor_datatype(finnonnx_in_tensor_name)
finnonnx_model_out_dt = model_for_sim.get_tensor_datatype(finnonnx_out_tensor_name)
print("Input tensor datatype: %s" % str(finnonnx_model_in_dt.name))
print("Output tensor datatype: %s" % str(finnonnx_model_out_dt.name))
print("List of node operator types in the graph: ")
print([x.op_type for x in model_for_sim.graph.node])

Input tensor name: input.1
Output tensor name: 24
Input tensor shape: [1, 1, 28, 28]
Output tensor shape: [1, 10]
Input tensor datatype: FLOAT32
Output tensor datatype: FLOAT32
List of node operator types in the graph: 
['Conv', 'Tanh', 'Pad', 'AveragePool', 'Conv', 'Tanh', 'Pad', 'AveragePool', 'Flatten', 'Gemm', 'Tanh', 'Gemm', 'Tanh', 'Gemm']


In [18]:
from qonnx.transformation.general import GiveReadableTensorNames, GiveUniqueNodeNames, RemoveStaticGraphInputs
from qonnx.transformation.infer_shapes import InferShapes
from qonnx.transformation.infer_datatypes import InferDataTypes
from qonnx.transformation.fold_constants import FoldConstants

model_for_sim = model_for_sim.transform(InferShapes())
model_for_sim = model_for_sim.transform(FoldConstants())
model_for_sim = model_for_sim.transform(GiveUniqueNodeNames())
model_for_sim = model_for_sim.transform(GiveReadableTensorNames())
model_for_sim = model_for_sim.transform(InferDataTypes())
model_for_sim = model_for_sim.transform(RemoveStaticGraphInputs())

verif_model_filename = "lenet5.onnx"
model_for_sim.save(verif_model_filename)

In [19]:
import finn.core.onnx_exec as oxe

def inference_with_finn_onnx(current_inp):
    finnonnx_in_tensor_name = model_for_sim.graph.input[0].name
    finnonnx_model_in_shape = model_for_sim.get_tensor_shape(finnonnx_in_tensor_name)
    finnonnx_out_tensor_name = model_for_sim.graph.output[0].name
    # convert input to numpy for FINN
    current_inp = current_inp.detach().numpy()
    # reshape to expected input (add 1 for batch dimension)
    current_inp = current_inp.reshape(finnonnx_model_in_shape)
    # create the input dictionary
    input_dict = {finnonnx_in_tensor_name : current_inp} 
    # run with FINN's execute_onnx
    output_dict = oxe.execute_onnx(model_for_sim, input_dict)
    #get the output tensor
    finn_outputs = output_dict[finnonnx_out_tensor_name] 
    #return finn_output
    finn_outputs = finn_outputs.tolist()
    finn_output = finn_outputs[0].index(max(finn_outputs[0]))
    return finn_output

In [20]:
# See random images with their labels
torch.manual_seed(42)  # setting random seed

ok = 0
nok = 0

for i in range(1, len(test_dataset) - 1):
    img, label_gt = test_dataset[i]
    img_temp = img.unsqueeze(dim=0).to(device)
    brevitas_output = inference_with_brevitas(img_temp)
    finn_output = inference_with_finn_onnx(img)
    # compare the outputs
    ok += 1 if finn_output == brevitas_output else 0
    nok += 1 if finn_output != brevitas_output else 0

print("ok = ", ok)
print("nok = ", nok)

ok =  9998
nok =  0
