# FINN - End-to-End Flow
-----------------------------------------------------------------

In this notebook, we will show how to take a simple, binarized, fully-connected network trained on the CIFAR10 data set and take it all the way down to a customized bitfile running on a PYNQ board. 

This notebook is quite lengthy, and some of the cells (involving Vivado synthesis) may take up to an hour to finish running. To let you save and resume your progress, we will save the intermediate ONNX models that are generated in the various steps to disk, so that you can jump back directly to where you left off.

In [1]:
import os

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms
import torch.nn.functional as F
from tqdm import tqdm
from torch.optim.lr_scheduler import OneCycleLR, ReduceLROnPlateau
from finn.util.visualization import showSrc, showInNetron
%matplotlib inline

## 1. Brevitas export <a id='brev_exp'></a>
FINN expects an ONNX model as input. This can be a model trained with [Brevitas](https://github.com/Xilinx/brevitas). Brevitas is a PyTorch library for quantization-aware training and the FINN Docker image comes with several [example Brevitas networks](https://github.com/Xilinx/brevitas/tree/master/src/brevitas_examples/bnn_pynq). 

### Load Dataset

In [2]:
from torchvision import transforms
import numpy as np

class GetTransforms():
    '''Returns a list of transformations when type as requested amongst train/test
       Transforms('train') = list of transforms to apply on training data
       Transforms('test') = list of transforms to apply on testing data'''

    def __init__(self):
        pass

    def trainparams(self):
        train_transformations = [ #resises the image so it can be perfect for our model.
            transforms.RandomHorizontalFlip(), # FLips the image w.r.t horizontal axis
            transforms.RandomRotation((-7,7)),     #Rotates the image to a specified angel
            transforms.RandomAffine(0, shear=10, scale=(0.8,1.2)), #Performs actions like zooms, change shear angles.
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2), # Set the color params
            transforms.ToTensor(), # comvert the image to tensor so that it can work with torch
            transforms.Normalize((0.491, 0.482, 0.446), (0.247, 0.243, 0.261)) #Normalize all the images
            ]

        return train_transformations

    def testparams(self):
        test_transforms = [
            transforms.ToTensor(),
            transforms.Normalize((0.491, 0.482, 0.446), (0.247, 0.243, 0.261))
        ]
        return test_transforms

In [3]:
transformations = GetTransforms()
train_transforms = transforms.Compose(transformations.trainparams())
test_transforms = transforms.Compose(transformations.testparams())


class GetCIFAR10_TrainData():
    def __init__(self, dir_name:str):
        self.dirname = dir_name

    def download_train_data(self):
        return datasets.CIFAR10('resnet18/data', train=True, download=True, transform=train_transforms)

    def download_test_data(self):
        return datasets.CIFAR10('resnet18/data', train=False, download=True, transform=test_transforms)

In [4]:
import os
data = GetCIFAR10_TrainData(os.chdir(".."))
trainset = data.download_train_data()
testset = data.download_test_data()
trainloader = torch.utils.data.DataLoader(trainset, batch_size=592,
                                          shuffle=True, num_workers=0)
testloader = torch.utils.data.DataLoader(testset, batch_size=592,
                                         shuffle=False, num_workers=0)

Files already downloaded and verified
Files already downloaded and verified


### Define a PyTorch Device

GPUs can significantly speed-up training of deep neural networks. We check for availability of a GPU and if so define it as target device.

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Target device: " + str(device))

Target device: cpu


### Define the Model 

In [6]:
from brevitas.nn import QuantConv2d, QuantLinear

torch.manual_seed(0)

weight_bit_width = 8

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, dropout=0.0):
        super(BasicBlock, self).__init__()
        self.conv1 = QuantConv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=True, weight_bit_width=weight_bit_width,quant_type="int")
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = QuantConv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=True, weight_bit_width=weight_bit_width,quant_type="int")
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                QuantConv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=True, weight_bit_width=weight_bit_width,quant_type="int"),
                nn.BatchNorm2d(self.expansion*planes)
            )
        self.dropout = dropout

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.dropout(out, p=self.dropout)
        out = self.bn2(self.conv2(out))
        out = F.dropout(out, p=self.dropout)
        out += self.shortcut(x)
        out = F.relu(out)
        out = F.dropout(out, p=self.dropout)
        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = QuantConv2d(in_planes, planes, kernel_size=1, bias=True, weight_bit_width=weight_bit_width,quant_type="int")
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = QuantConv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True, weight_bit_width=weight_bit_width,quant_type="int")
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = QuantConv2d(planes, self.expansion*planes, kernel_size=1, bias=True, weight_bit_width=weight_bit_width,quant_type="int")
        self.bn3 = nn.BatchNorm2d(self.expansion*planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                QuantConv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=True, weight_bit_width=weight_bit_width,quant_type="int"),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=200, dropout=0.0):
        super(ResNet, self).__init__()
        self.in_planes = 64
        self.dropout = dropout

        self.conv1 = QuantConv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=True, weight_bit_width=weight_bit_width,quant_type="int")
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = QuantLinear(512*block.expansion, num_classes, bias=True, weight_bit_width=weight_bit_width,quant_type="int")

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride, dropout=self.dropout))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.dropout(out, p=self.dropout)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.adaptive_avg_pool2d(out, 1)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


def ResNet18(num_classes=10, dropout=0.0):
    return ResNet(BasicBlock, [2,2,2,2], num_classes=num_classes, dropout=dropout)


print("Model define")

Model define


### Train and Test

In [7]:
classes = ["%s" % i for i in range(10)]

In [8]:
def test(model, device, test_loader, criterion, classes, test_losses, test_accs,
         misclassified_imgs, correct_imgs, is_last_epoch):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss +=criterion(output, target).item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            is_correct = pred.eq(target.view_as(pred))
            if is_last_epoch:
              misclassified_inds = (is_correct==0).nonzero()[:,0]
              for mis_ind in misclassified_inds:
                if len(misclassified_imgs) == 25:
                  break
                misclassified_imgs.append({
                    "target": target[mis_ind].cpu().numpy(),
                    "pred": pred[mis_ind][0].cpu().numpy(),
                    "img": data[mis_ind]
                })
              
              correct_inds = (is_correct==1).nonzero()[:,0]
              for ind in correct_inds:
                if len(correct_imgs) == 25:
                  break
                correct_imgs.append({
                    "target": target[ind].cpu().numpy(),
                    "pred": pred[ind][0].cpu().numpy(),
                    "img": data[ind]
                })
            correct += is_correct.sum().item()

    test_loss /= len(test_loader)
    test_losses.append(test_loss)
    
    test_acc = 100. * correct / len(test_loader.dataset)
    test_accs.append(test_acc)

    if test_acc >= 90.0:
        classwise_acc(model, device, test_loader, classes)

    print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset), test_acc))

In [9]:
def classwise_acc(model, device, test_loader, classes):
    class_correct = list(0. for i in range(10))
    class_total = list(0. for i in range(10))
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            c = (predicted == labels).squeeze()
            for i in range(4):
                label = labels[i]
                class_correct[label] += c[i].item()
                class_total[label] += 1
    # print class-wise test accuracies
    print()
    for i in range(10):
      print('Accuracy of %5s : %2d %%' % (
          classes[i], 100 * class_correct[i] / class_total[i]))
    print()

In [10]:
import os
import torch

In [11]:
# --- Helper function ---
def remove_export_handlers(model):
    count = 0
    for module in model.modules():
        if hasattr(module, "export_handler"):
            module.export_handler = None
            count += 1
    print(f"✅ Removed export_handler from {count} Quant layers.")


In [12]:
# --- Main export pipeline ---
# Step 1: Construct model
model = ResNet18(num_classes=10)

# Step 2: Load weights
trained_state_dict = torch.load("./models/quentresnet18_weight8.pth", map_location='cpu')
model.load_state_dict(trained_state_dict, strict=False)

# Step 3: Remove export_handler from all quant layers
remove_export_handlers(model)

# Step 4: Prepare for export
model.eval()
model.cpu()

✅ Removed export_handler from 105 Quant layers.


ResNet(
  (conv1): QuantConv2d(
    3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
    (input_quant): ActQuantProxyFromInjector(
      (_zero_hw_sentinel): StatelessBuffer()
    )
    (output_quant): ActQuantProxyFromInjector(
      (_zero_hw_sentinel): StatelessBuffer()
    )
    (weight_quant): WeightQuantProxyFromInjector(
      (_zero_hw_sentinel): StatelessBuffer()
      (tensor_quant): RescalingIntQuant(
        (int_quant): IntQuant(
          (float_to_int_impl): RoundSte()
          (tensor_clamp_impl): TensorClampSte()
          (delay_wrapper): DelayWrapper(
            (delay_impl): _NoDelay()
          )
        )
        (scaling_impl): StatsFromParameterScaling(
          (parameter_list_stats): _ParameterListStats(
            (first_tracked_param): _ViewParameterWrapper(
              (view_shape_impl): OverTensorView()
            )
            (stats): _Stats(
              (stats_impl): AbsMax()
            )
          )
          (stats_scaling_impl): _St

In [13]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum= 0.9)
test_losses, train_losses, test_accs, train_accs = [], [], [], []
misclassified_imgs, correct_imgs = [], []

In [14]:
print(model.conv1.weight.dtype)

torch.float32


In [14]:
#Test for accuracy
exact_acc = test(model, device, testloader, criterion, classes, test_losses,
                 test_accs, misclassified_imgs, correct_imgs,False)

  return super(Tensor, self).rename(names)



Accuracy of     0 : 100 %
Accuracy of     1 : 100 %
Accuracy of     2 : 75 %
Accuracy of     3 : 85 %
Accuracy of     4 : 100 %
Accuracy of     5 : 100 %
Accuracy of     6 : 100 %
Accuracy of     7 : 100 %
Accuracy of     8 : 100 %
Accuracy of     9 : 66 %

Test set: Average loss: 0.3311, Accuracy: 9024/10000 (90.24%)



# 2. Export to FINN-ONNX

ONNX is an open format built to represent machine learning models, and the FINN compiler expects an ONNX model as input. We'll now export our network into ONNX to be imported and used in FINN for the next notebooks. Note that the particular ONNX representation used for FINN differs from standard ONNX, you can read more about this here.

You can see below how we export a trained network in Brevitas into a FINN-compatible ONNX representation. Note how we create a QuantTensor instance with dummy data to tell Brevitas how our inputs look like, which will be used to set the input quantization annotation on the exported model.

In [16]:
from brevitas.export import export_brevitas_onnx
# Step 5: Export ONNX
onnx_export_path = "hardware/quantresnet18_weight8_files/quantresnet18_raw.onnx"
# Step 5: Export ONNX (keep initializers as inputs to keep weights in graph)
export_brevitas_onnx(
    model,
    input_t=torch.randn(1, 3, 32, 32),
    export_path=onnx_export_path,
    keep_initializers_as_inputs=True  # Changed here
)
print(f"✅ Exported to {onnx_export_path}")

[W NNPACK.cpp:53] Could not initialize NNPACK! Reason: Unsupported hardware.


✅ Exported to hardware/quantresnet18_weight8_files/quantresnet18_raw.onnx


## Compare FINN & Brevitas execution

Load brevitas model.

In [17]:
brevitas_model =  ResNet18(num_classes=10)
trained_state_dict = torch.load("./models/quentresnet18_weight8.pth", map_location='cpu')
brevitas_model.load_state_dict(trained_state_dict, strict=False)

<All keys matched successfully>

In [18]:
def inference(model, current_inp, device="cpu"):
    model.eval()
    current_inp = current_inp.to(device)

    with torch.no_grad():
        output = model(current_inp)
        _, predicted = torch.max(output, 1)
    
        # Convert output to a Python list for further use (e.g., JSON export)
        output_list = output.cpu().detach().numpy().tolist()

    return output_list, predicted

Now that we have the model in .onnx format, we can work with it using FINN. To import it into FINN, we'll use the ModelWrapper. It is a wrapper around the ONNX model which provides several helper functions to make it easier to work with the model.

In [21]:
import torch 
import torch.nn as nn 
import torch.nn.functional as F 
from brevitas.nn import QuantConv2d, QuantLinear 
from brevitas.export import export_brevitas_onnx 
from qonnx.core.modelwrapper import ModelWrapper 
from qonnx.transformation.fold_constants import FoldConstants 
from qonnx.transformation.infer_shapes import InferShapes 
from qonnx.transformation.infer_datatypes import InferDataTypes 
from qonnx.transformation.general import ( GiveUniqueNodeNames, GiveReadableTensorNames, RemoveStaticGraphInputs ) 
from finn.transformation.streamline import Streamline

In [22]:
import onnx
from onnx import numpy_helper
import numpy as np

def fold_quant_nodes_producing_weights(model_wrapper):
    model = model_wrapper.model
    init_map = {init.name: init for init in model.graph.initializer}
    nodes_to_remove = []
    new_inits = []

    for node in model.graph.node:
        if node.op_type == "Quant":
            quant_out = node.output[0]

            # Find Conv or Gemm nodes using this quant output as weight input
            is_used_as_weight = False
            for conv_node in model.graph.node:
                if conv_node.op_type in ["Conv", "Gemm"]:
                    if quant_out in conv_node.input:
                        is_used_as_weight = True
                        break

            if not is_used_as_weight:
                continue  # skip quant nodes not used as weights

            # Check if input to Quant is initializer (unquantized weights)
            quant_in_name = node.input[0]
            if quant_in_name not in init_map:
                print(f"Input to Quant node '{node.name}' ({quant_in_name}) not found as initializer!")
                continue

            # Load weight tensor as numpy
            weight_init = init_map[quant_in_name]
            weight_np = numpy_helper.to_array(weight_init)

            # TODO: Here you need actual quantization logic based on Quant node attributes
            # For now, assume quantized weights = original weights (no quantization applied)
            quantized_weight_np = weight_np  # Replace with actual quantization logic

            # Create new initializer with quant_out name
            new_init = numpy_helper.from_array(quantized_weight_np, quant_out)
            new_inits.append(new_init)

            # Update all Conv/Gemm nodes input from quant_out to use the new initializer (same name)
            # No change needed because name is the same, just now it will be a constant initializer

            # Mark Quant node for removal
            nodes_to_remove.append(node)

    # Remove Quant nodes from graph
    for node in nodes_to_remove:
        model.graph.node.remove(node)

    # Add new initializers
    model.graph.initializer.extend(new_inits)

    # Return updated model wrapper
    return ModelWrapper(model)

# --- Insert this folding step after your FoldConstants() but before Streamline() ---

model_wrapper = ModelWrapper(onnx_export_path)
model_wrapper.set_tensor_shape(model_wrapper.graph.input[0].name, [592, 3, 32, 32])
model_wrapper = model_wrapper.transform(InferShapes())
model_wrapper = model_wrapper.transform(InferDataTypes())
model_wrapper = model_wrapper.transform(GiveUniqueNodeNames())
model_wrapper = model_wrapper.transform(GiveReadableTensorNames())
model_wrapper = model_wrapper.transform(RemoveStaticGraphInputs())
model_wrapper = model_wrapper.transform(FoldConstants())

# Fold Quant nodes producing weights into initializers
model_wrapper = fold_quant_nodes_producing_weights(model_wrapper)

# Save cleaned-up model
verif_model_filename = "hardware/quantresnet18_weight8_files/quantresnet18_weight8_tidy.onnx"
model_wrapper.save(verif_model_filename)

print(f"Cleaned model saved to {verif_model_filename}")

Cleaned model saved to hardware/quantresnet18_weight8_files/quantresnet18_weight8_tidy.onnx


In [23]:
import torch
import numpy as np
import finn.core.onnx_exec as oxe

def inference_with_finn_onnx(current_inp: torch.Tensor, model_for_sim):
    """
    Run inference using a FINN-compatible ONNX model.

    Args:
        current_inp (torch.Tensor): Input tensor of shape (1, C, H, W).
        model_for_sim (ModelWrapper): Cleaned and transformed FINN ONNX model.

    Returns:
        logits (np.ndarray): Raw output from the model.
        predicted (np.ndarray): Predicted class index (argmax).
    """
    input_name = model_for_sim.graph.input[0].name
    output_name = model_for_sim.graph.output[0].name

    # Ensure correct input shape and type
    current_inp = current_inp.detach().cpu().numpy()
    if current_inp.ndim == 3:
        current_inp = np.expand_dims(current_inp, axis=0)  # Add batch dim
    current_inp = current_inp.astype(np.float32)

    # Build input dict and run inference
    input_dict = {input_name: current_inp}
    output_dict = oxe.execute_onnx(model_for_sim, input_dict)

    # Extract output and predicted class
    logits = output_dict[output_name]
    predicted = np.argmax(logits, axis=1)

    return logits, predicted

In [24]:
import torch

torch.manual_seed(42)

ok = 0
nok = 0
expected_batch_size = 592

for img, label_gt in testloader:
    img, label_gt = img.to(device), label_gt.to(device)

    if img.shape[0] == 0:
        continue  # skip empty batches

    original_batch_size = img.shape[0]

    # Pad or truncate to expected batch size
    if original_batch_size < expected_batch_size:
        padding_size = expected_batch_size - original_batch_size
        img = torch.cat([img, img[:padding_size]], dim=0)
        label_gt = torch.cat([label_gt, label_gt[:padding_size]], dim=0)
    elif original_batch_size > expected_batch_size:
        img = img[:expected_batch_size]
        label_gt = label_gt[:expected_batch_size]

    # Brevitas inference
    _, predicted = inference(brevitas_model, img, device=device)

    # FINN ONNX inference
    _, predicted_onnx = inference_with_finn_onnx(img, model_wrapper)
    predicted_onnx_tensor = torch.tensor(predicted_onnx, dtype=torch.long, device=device)

    # Compare predictions between Brevitas and FINN
    correct = predicted.eq(predicted_onnx_tensor).sum().item()
    incorrect = expected_batch_size - correct

    ok += correct
    nok += incorrect

# Final results
print(f"Total correct: {ok}")
print(f"Total incorrect: {nok}")
print(f"Agreement Accuracy: {100. * ok / (ok + nok):.2f}%")

Total correct: 10064
Total incorrect: 0
Agreement Accuracy: 100.00%


## Adding Pre- and Postprocessing <a id='prepost'></a>

In many cases, it's common to apply some preprocessing to the raw data in a machine learning framework prior to training. For image classification networks, this may include conversion of raw 8-bit RGB values into floating point values between 0 and 1. Similarly, at the output of the network some postprocessing may be performed during deployment, such as extracting the indices of the classifications with the largest value (top-K indices).

In FINN, we can bake some of these pre/postprocessing operatings into the graph, and in some cases these can be highly beneficial for performance by allowing our accelerator to directly consume raw data instead of going through CPU preprocessing. 

We'll demonstrate this for our small image classification network as follows. Brevitas preprocesses BNN-PYNQ network inputs with `torchvision.transforms.ToTensor()` [prior to training](https://github.com/Xilinx/brevitas/blob/master/src/brevitas_examples/bnn_pynq/trainer.py#L93), which converts 8-bit RGB values into floats between 0 and 1 by dividing the input by 255. We can achieve the same effect in FINN by exporting a single-node ONNX graph for division by 255 (which already exists as `finn.util.pytorch.ToTensor` and merging this with our original model. Finally, we're going to mark our input tensor as 8-bit to let FINN know which level of precision to use.

In [25]:
import torch
import torch.nn as nn

class ToTensorWrapper(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        # x: UINT8 NHWC (N, H, W, C)
        x = x.permute(0, 3, 1, 2).float() / 255.0  # NHWC -> NCHW, normalize
        return x


In [26]:
from brevitas.export import export_qonnx
from qonnx.util.cleanup import cleanup as qonnx_cleanup

# Updated dummy input with correct batch size
dummy_input = torch.randint(0, 256, (592, 32, 32, 3), dtype=torch.uint8)

# Export preprocessing model again
model = ToTensorWrapper()
preproc_path = "hardware/quantresnet18_weight8_files/quantresnet18_weight8_a1_with_preproc.onnx"
export_qonnx(model, dummy_input, preproc_path)
qonnx_cleanup(preproc_path, out_file=preproc_path)



In [27]:
from qonnx.core.modelwrapper import ModelWrapper
from finn.transformation.qonnx.convert_qonnx_to_finn import ConvertQONNXtoFINN
from qonnx.transformation.merge_onnx_models import MergeONNXModels
from qonnx.core.datatype import DataType

# Load preprocessing model and convert
pre_model = ModelWrapper(preproc_path)
pre_model = pre_model.transform(ConvertQONNXtoFINN())

# Load core model
core_model = ModelWrapper("hardware/quantresnet18_weight8_files/quantresnet18_weight8_tidy.onnx")

# Merge
core_model = core_model.transform(MergeONNXModels(pre_model))

# Set proper input shape and type
core_model.set_tensor_shape(core_model.graph.input[0].name, [592, 32, 32, 3])
core_model.set_tensor_datatype(core_model.graph.input[0].name, DataType["UINT8"])

# Save final model
core_model.save("hardware/quantresnet18_weight8_files/quantresnet18_weight8_final.onnx")




In [28]:
print("Core model input shape:", core_model.get_tensor_shape(core_model.graph.input[0].name))

Core model input shape: [592, 32, 32, 3]


In [29]:
from torchvision import datasets, transforms
import numpy as np
from finn.core.onnx_exec import execute_onnx

# Load CIFAR10 sample
transform = transforms.Compose([
    transforms.ToTensor(),  # Converts to [0,1] float
    transforms.ConvertImageDtype(torch.uint8)  # Simulate raw input as uint8
])
dataset = datasets.CIFAR10(root=".", train=False, download=True, transform=transform)

Files already downloaded and verified


In [30]:
# Create batch of 592 uint8 CIFAR images
imgs = []
labels = []
for i in range(592):
    img, label = dataset[i]
    img_np = (img.numpy() * 255).astype(np.uint8)
    imgs.append(img_np)
    labels.append(label)

# Stack and reshape
batch_np = np.stack(imgs, axis=0)  # Still (592, 3, 32, 32)
batch_np = np.transpose(batch_np, (0, 2, 3, 1))  # Now (592, 32, 32, 3)
input_dict = {core_model.graph.input[0].name: batch_np.astype(np.uint8)}

# Run inference
output_dict = execute_onnx(core_model, input_dict)
predictions = output_dict[core_model.graph.output[0].name]

print("Predictions shape:", predictions.shape)
print("First 10 predictions:", predictions[:10])


Predictions shape: (592, 10)
First 10 predictions: [[-0.3501871  -4.841082   -0.6999654   3.8095608   3.0655427   3.1280637
  -0.04158868 -1.3247869  -1.5421141   0.47182822]
 [ 3.2556887  -1.0761231  -0.02893056  2.819892   -1.1279145  -1.8868216
   0.21485473 -3.3927462   2.8908088   0.06721953]
 [-1.1760231  -6.011578    1.3738385   2.233866    4.4420485   2.1159492
   4.670891   -3.52532    -0.63752943 -2.1631691 ]
 [ 2.5881476  -6.219797    5.804773    4.8108077  -1.0688319   1.214709
   2.813208   -2.8741474  -1.1683418  -4.0391526 ]
 [ 1.1528944  -6.3383408  -0.11458682  4.0101094   1.5729      3.8940346
   4.177797   -2.6356783  -0.8261254  -3.0152466 ]
 [ 1.7585313  -3.8780274  -1.0171657   1.2420057  -1.336038    3.5041962
   5.716073   -4.68685     2.6686285  -2.3892045 ]
 [-0.92201155 -3.2746606  -3.262432    7.4282136  -0.65232545  0.49708307
   3.0408783  -1.1937056   2.8359857  -2.6004984 ]
 [ 2.7571666  -5.8486104   0.23284847  2.3388171   1.246488    1.2227845
   2.610

You can observe two changes in the graph above: a Div node has appeared in the beginning to perform the input preprocessing, and the global_in tensor now has a quantization annotation to mark it as an unsigned 8-bit value.

For the postprocessing we'll insert a TopK node for k=1 at the end of our graph. This will extract the index (class number) for the largest-valued output.


In [31]:
from qonnx.transformation.insert_topk import InsertTopK

# postprocessing: insert Top-1 node at the end
model = core_model.transform(InsertTopK(k=1))
chkpt_name = "hardware/quantresnet18_weight8_files/quantresnet18_weight8_final.onnx"
# tidy-up again
model = model.transform(InferShapes())
model = model.transform(FoldConstants())
model = model.transform(GiveUniqueNodeNames())
model = model.transform(GiveReadableTensorNames())
model = model.transform(InferDataTypes())
model = model.transform(RemoveStaticGraphInputs())


model.save(chkpt_name)


In [32]:
input_dict = {model.graph.input[0].name: batch_np.astype(np.uint8)}  # use uint8 now
output_dict = execute_onnx(model, input_dict)
top1_predictions = output_dict[model.graph.output[0].name]

print("Top-1 shape:", top1_predictions.shape)
print("Top-1 predictions (first 10):", top1_predictions[:10])


Top-1 shape: (592, 1)
Top-1 predictions (first 10): [[3]
 [0]
 [6]
 [2]
 [6]
 [6]
 [3]
 [8]
 [3]
 [3]]


## Streamlining <a id='streamline'></a>
Streamlining is a transformation containing several sub-transformations. The goal of streamlining is to eliminate floating point operations by moving them around, then collapsing them into one operation and in the last step transform them into multi-thresholding nodes. For more information on the theoretical background of this, see [this paper](https://arxiv.org/pdf/1709.04060).

Let's have a look at which sub-transformations `Streamline` consists of:

In [33]:
from finn.transformation.streamline.reorder import MoveScalarLinearPastInvariants
import finn.transformation.streamline.absorb as absorb
# Now you can streamline safely
model = ModelWrapper("hardware/quantresnet18_weight8_files/quantresnet18_weight8_final.onnx")
model = model.transform(Streamline())
model = model.transform(MoveScalarLinearPastInvariants())
# Save final model
onnx_final_path = "hardware/quantresnet18_weight8_files/quantresnet18_streamlined.onnx"
model.save(onnx_final_path)
print(f"✅ Saved streamlined model to {onnx_final_path}")

✅ Saved streamlined model to hardware/quantresnet18_weight8_files/quantresnet18_streamlined.onnx


You can see that the network has become simplified considerably compared to the previous step -- a lot of nodes have disappeared between the `MatMul` layers. 

**The current implementation of streamlining is highly network-specific and may not work for your network if its topology is very different than the example network here. We hope to rectify this in future releases.**

Our example network is a quantized network with 1-bit bipolar (-1, +1 values) precision, and we want FINN to implement them as XNOR-popcount operations [as described in the original FINN paper](https://arxiv.org/pdf/1612.07119). For this reason, after streamlining, the resulting bipolar matrix multiplications are converted into xnorpopcount operations. This transformation produces operations that are again collapsed and converted into thresholds. This procedure is shown below. 

In [34]:
from qonnx.transformation.bipolar_to_xnor import ConvertBipolarMatMulToXnorPopcount
from finn.transformation.streamline.round_thresholds import RoundAndClipThresholds
from qonnx.transformation.infer_data_layouts import InferDataLayouts
from qonnx.transformation.general import RemoveUnusedTensors

model = model.transform(ConvertBipolarMatMulToXnorPopcount())
model = model.transform(absorb.AbsorbAddIntoMultiThreshold())
model = model.transform(absorb.AbsorbMulIntoMultiThreshold())
model = model.transform(absorb.AbsorbScalarMulAddIntoTopK())
model = model.transform(RoundAndClipThresholds())
model = model.transform(InferDataLayouts())
model = model.transform(RemoveUnusedTensors())

model.save("hardware/quantresnet18_weight8_files/quantresnet18_ready_for_hw_conversion.onnx")

In [35]:
model = ModelWrapper("hardware/quantresnet18_weight8_files/quantresnet18_ready_for_hw_conversion.onnx")
print(set([node.op_type for node in model.graph.node]))

{'Transpose', 'GlobalAveragePool', 'Gemm', 'Reshape', 'Mul', 'Dropout', 'Relu', 'Add', 'Conv', 'TopK', 'Cast'}


Observe the pairs of `XnorPopcountmatMul` and `MultiThreshold` layers following each other -- this is the particular pattern that the next step will be looking for in order to convert them to hardware (HW) layers.

## 3. Conversion to HW layers <a id='hw_layers'></a>
Converts the nodes to HW layers, these layers are abstraction layers that do not directly correspond to an HLS or Verilog implementation but they will be converted in either one later in the flow. 

In [36]:
import finn.transformation.fpgadataflow.convert_to_hw_layers as to_hw

model = ModelWrapper("hardware/quantresnet18_weight8_files/quantresnet18_ready_for_hw_conversion.onnx")
model = model.transform(to_hw.InferBinaryMatrixVectorActivation())
# TopK to LabelSelect
model = model.transform(to_hw.InferLabelSelectLayer())
# input quantization (if any) to standalone thresholding
model = model.transform(to_hw.InferThresholdingLayer())
model.save("hardware/quantresnet18_weight8_files/quantresnet18_hw_layers.onnx")

model = ModelWrapper("hardware/quantresnet18_weight8_files/quantresnet18_hw_layers.onnx")
print(set([node.op_type for node in model.graph.node]))

{'Transpose', 'GlobalAveragePool', 'Gemm', 'Reshape', 'Mul', 'Dropout', 'Relu', 'Add', 'Conv', 'TopK', 'Cast'}


In [37]:
import finn.transformation.fpgadataflow.convert_to_hw_layers as to_hw
print(dir(to_hw))



In [38]:
from finn.transformation.fpgadataflow.convert_to_hw_layers import (
    InferBinaryMatrixVectorActivation,
    InferQuantizedMatrixVectorActivation,
    InferLabelSelectLayer,
    InferThresholdingLayer,
    InferStreamingMaxPool,
    InferGlobalAccPoolLayer,
    InferStreamingEltwise,
    InferConvInpGen
)

In [39]:
from finn.transformation.fpgadataflow.convert_to_hw_layers import (
    InferBinaryMatrixVectorActivation,
    InferQuantizedMatrixVectorActivation,
    InferLabelSelectLayer,
    InferThresholdingLayer,
    InferStreamingMaxPool,
    InferGlobalAccPoolLayer,
    InferStreamingEltwise,
    InferConvInpGen
)

model = ModelWrapper("hardware/quantresnet18_weight8_files/quantresnet18_hw_layers.onnx")

model = model.transform(InferBinaryMatrixVectorActivation())
model = model.transform(InferQuantizedMatrixVectorActivation())
model = model.transform(InferLabelSelectLayer())
model = model.transform(InferThresholdingLayer())
model = model.transform(InferStreamingMaxPool())
model = model.transform(InferGlobalAccPoolLayer())
model = model.transform(InferStreamingEltwise())
model = model.transform(InferConvInpGen())

model.save("hardware/quantresnet18_weight8_files/quantresnet18_hw_layers_updated.onnx")

### Creating a Dataflow Partition <a id='dataflow_partition'></a>

In the graph above, you can see that there is a mixture of FINN HW layers (`MVAU` and `Thresholding`) with one regular ONNX layers (Reshape). To create a bitstream, FINN needs a model with only HW layers. In order to achieve this, we will use the `CreateDataflowPartition` transformation to create a "dataflow partition" in this graph, separating out the HLS layers into another model, and replacing them with a placeholder layer called StreamingDataflowPartition.

In [40]:
from finn.transformation.fpgadataflow.create_dataflow_partition import CreateDataflowPartition

model = ModelWrapper("hardware/quantresnet18_weight8_files/quantresnet18_hw_layers_updated.onnx")
parent_model = model.transform(CreateDataflowPartition())
parent_model.save("hardware/quantresnet18_weight8_files/quantresnet18_dataflow_parent.onnx")

In [41]:
print(set([node.op_type for node in parent_model.graph.node]))

{'Transpose', 'GlobalAveragePool', 'Gemm', 'Reshape', 'Mul', 'Dropout', 'Relu', 'Add', 'Conv', 'TopK', 'Cast'}


In [42]:
sdp_nodes = parent_model.get_nodes_by_op_type("StreamingDataflowPartition")
if len(sdp_nodes) == 0:
    print("❌ No StreamingDataflowPartition node found in the model.")
else:
    sdp_node = getCustomOp(sdp_nodes[0])
    dataflow_model_filename = sdp_node.get_nodeattr("model")
    print("✅ Found dataflow model:", dataflow_model_filename)

❌ No StreamingDataflowPartition node found in the model.


In [43]:
model = ModelWrapper("hardware/quantresnet18_weight8_files/quantresnet18_hw_layers_updated.onnx")

In [44]:
for node in model.graph.node:
        print(node.name, node.op_type)

Transpose_0 Transpose
Cast_0 Cast
Conv_0 Conv
Mul_0 Mul
Add_0 Add
Relu_0 Relu
Dropout_0 Dropout
Conv_1 Conv
Mul_1 Mul
Add_1 Add
Relu_1 Relu
Dropout_1 Dropout
Conv_2 Conv
Mul_2 Mul
Add_2 Add
Dropout_2 Dropout
Add_3 Add
Relu_2 Relu
Dropout_3 Dropout
Conv_3 Conv
Mul_3 Mul
Add_4 Add
Relu_3 Relu
Dropout_4 Dropout
Conv_4 Conv
Mul_4 Mul
Add_5 Add
Dropout_5 Dropout
Add_6 Add
Relu_4 Relu
Dropout_6 Dropout
Conv_5 Conv
Conv_6 Conv
Mul_5 Mul
Mul_6 Mul
Add_7 Add
Add_8 Add
Relu_5 Relu
Dropout_7 Dropout
Conv_7 Conv
Mul_7 Mul
Add_9 Add
Dropout_8 Dropout
Add_10 Add
Relu_6 Relu
Dropout_9 Dropout
Conv_8 Conv
Mul_8 Mul
Add_11 Add
Relu_7 Relu
Dropout_10 Dropout
Conv_9 Conv
Mul_9 Mul
Add_12 Add
Dropout_11 Dropout
Add_13 Add
Relu_8 Relu
Dropout_12 Dropout
Conv_10 Conv
Conv_11 Conv
Mul_10 Mul
Mul_11 Mul
Add_14 Add
Add_15 Add
Relu_9 Relu
Dropout_13 Dropout
Conv_12 Conv
Mul_12 Mul
Add_16 Add
Dropout_14 Dropout
Add_17 Add
Relu_10 Relu
Dropout_15 Dropout
Conv_13 Conv
Mul_13 Mul
Add_18 Add
Relu_11 Relu
Dropout_16 

In [45]:
from finn.util.basic import show_net_structure

ImportError: cannot import name 'show_net_structure' from 'finn.util.basic' (/home/pamela/finn/src/finn/util/basic.py)

In [None]:
for node in model.graph.node:
    if node.op_type in ["Conv","Gemm"]:
        print(node.name, node.input)

In [46]:
for init in model.model.graph.initializer:
    if "Conv" in init.name:
        print(init.name, init.data_type)

Conv_0_param1 1
Conv_1_param1 1
Conv_2_param1 1
Conv_3_param1 1
Conv_4_param1 1
Conv_5_param1 1
Conv_7_param1 1
Conv_6_param1 1
Conv_8_param1 1
Conv_9_param1 1
Conv_10_param1 1
Conv_12_param1 1
Conv_11_param1 1
Conv_13_param1 1
Conv_14_param1 1
Conv_15_param1 1
Conv_17_param1 1
Conv_16_param1 1
Conv_18_param1 1
Conv_19_param1 1
Conv_0_param0 1
Conv_1_param0 1
Conv_2_param0 1
Conv_3_param0 1
Conv_4_param0 1
Conv_5_param0 1
Conv_7_param0 1
Conv_6_param0 1
Conv_8_param0 1
Conv_9_param0 1
Conv_10_param0 1
Conv_12_param0 1
Conv_11_param0 1
Conv_13_param0 1
Conv_14_param0 1
Conv_15_param0 1
Conv_17_param0 1
Conv_16_param0 1
Conv_18_param0 1
Conv_19_param0 1


In [15]:
for inp in model.graph.input:
    print(inp.name, model.get_tensor_datatype(inp.name))

AttributeError: 'ResNet' object has no attribute 'graph'