In [1]:
import torch
from torchinfo import summary

from torch.nn import BatchNorm1d
from torch.nn import BatchNorm2d
from torch.nn import MaxPool2d
from torch.nn import Module
from torch.nn import ModuleList

from brevitas.core.restrict_val import RestrictValueType
from brevitas.nn import QuantIdentity
from brevitas.nn import QuantConv2d
from brevitas.nn import QuantReLU
from brevitas.nn import TruncAvgPool2d
from brevitas.nn import QuantLinear

from brevitas.quant import TruncTo8bit


from brevitas.quant import Int8ActPerTensorFloat # For Quant ADD node
from common_imagenet import CommonIntWeightPerTensorQuant
from common_imagenet import CommonIntWeightPerChannelQuant
from common_imagenet import CommonUintActQuant
from common_imagenet import CommonIntActQuant # For initial Q1.7 Identity Layer
from tensor_norm import TensorNorm

from brevitas.export import export_qonnx

# Custom Quantizers

In [2]:
class MyWeightsQuant_PerTensor(CommonIntWeightPerTensorQuant):
    """``CommonIntWeightPerTensorQuant`` with the scale restricted to powers of two."""
    restrict_scaling_type = RestrictValueType.POWER_OF_TWO

class MyWeightsQuant_PerChannel(CommonIntWeightPerChannelQuant):
    """``CommonIntWeightPerChannelQuant`` with the scale restricted to powers of two."""
    restrict_scaling_type = RestrictValueType.POWER_OF_TWO

class MyReLUQuant(CommonUintActQuant):
    """``CommonUintActQuant`` (used by the QuantReLU layers) with the scale restricted to powers of two."""
    restrict_scaling_type = RestrictValueType.POWER_OF_TWO


# Tiny Model

In [3]:
class TINY_RESNET(Module):
    """Small quantized CNN with a single residual (skip) connection, built from Brevitas layers.

    Data flow, as implemented in ``forward``:

    1. The input is rescaled with ``2 * x - 1`` and quantized by a signed
       ``QuantIdentity`` (``in_bit_width`` bits, range ``[-1, 1 - 2**-7]``).
    2. ``conv_features``: 3x3 conv (``in_channels`` -> 12) + BatchNorm2d +
       QuantReLU, followed by a 2x2 max-pool with stride 2.
    3. Residual block: the pooled tensor is quantized once for the skip path
       (``quant_MP_1``) and once for the conv path (``quant_MP_0_0``). The
       conv path runs ``conv_branch`` (3x3 conv 12 -> 12 + BatchNorm2d +
       QuantReLU) and is re-quantized by ``quant_MP_0_1``. All three of these
       layers use the very same activation-quantizer proxy, so both addends
       are produced by one quantizer. Their sum is quantized by ``quant_add``.
    4. ``avg_pool`` (truncating average pool, kernel 112), flatten, then
       ``linear_features``: QuantLinear 12 -> 8 + BatchNorm1d + QuantReLU and
       QuantLinear 8 -> ``num_classes`` + TensorNorm.

    The 112x112 average pool together with the 12-feature input of the first
    linear layer means the network expects 224x224 inputs (224 / 2 = 112).

    Args:
        num_classes: Number of outputs of the final linear layer.
        weight_bit_width: Bit width of all conv and linear weights.
        act_bit_width: Bit width of the QuantReLU activations.
        in_bit_width: Bit width of the input QuantIdentity.
        in_channels: Number of channels of the input image.
    """

    def __init__(self,
                 num_classes=2,
                 weight_bit_width=4,
                 act_bit_width=4,
                 in_bit_width=8,
                 in_channels=3):
        super().__init__()

        hidden_channels = 12    # output width of both conv layers
        hidden_features = 8     # output width of the first linear layer
        residual_bit_width = 4  # kept fixed at 4 bits, independent of act_bit_width

        self.conv_features = ModuleList()
        self.conv_branch = ModuleList()
        self.linear_features = ModuleList()

        # Input 224x224x3 (the 2x2 max-pool below yields the 112x112 map
        # that the average pool consumes).
        self.conv_features.append(QuantIdentity(  # for Q1.7 input format -> sign.7bits
            act_quant=CommonIntActQuant,
            bit_width=in_bit_width,
            min_val=-1.0,
            max_val=1.0 - 2.0 ** (-7),
            narrow_range=False,
            restrict_scaling_type=RestrictValueType.POWER_OF_TWO))

        # CNNBlock 224x224
        #   conv1
        self.conv_features.append(
            QuantConv2d(
                kernel_size=3, stride=1, padding=1,
                in_channels=in_channels,
                out_channels=hidden_channels,
                bias=False,
                weight_quant=MyWeightsQuant_PerTensor,
                weight_bit_width=weight_bit_width))
        self.conv_features.append(BatchNorm2d(hidden_channels))
        self.conv_features.append(
            QuantReLU(
                act_quant=MyReLUQuant,
                bit_width=act_bit_width,
                return_quant_tensor=True))

        self.conv_features.append(MaxPool2d(kernel_size=2, stride=2))

        # CNNBlock 112x112
        #   conv2
        self.conv_branch.append(
            QuantConv2d(
                kernel_size=3, stride=1, padding=1,
                in_channels=hidden_channels,
                out_channels=hidden_channels,
                bias=False,
                weight_quant=MyWeightsQuant_PerTensor,
                weight_bit_width=weight_bit_width))
        self.conv_branch.append(BatchNorm2d(hidden_channels))
        self.conv_branch.append(
            QuantReLU(
                act_quant=MyReLUQuant,
                bit_width=act_bit_width,
                return_quant_tensor=True))

        # Convs Branch
        # quant_MP_0_0 owns the activation quantizer; the two layers below
        # reuse its proxy instance.
        self.quant_MP_0_0 = QuantIdentity(
            act_quant=Int8ActPerTensorFloat,
            bit_width=residual_bit_width,
            restrict_scaling_type=RestrictValueType.POWER_OF_TWO)
        # When `act_quant` is an already-built proxy, Brevitas does not apply
        # extra quantizer keyword arguments (it only warns that they are
        # unused), so none are passed here: bit width and scaling come from
        # quant_MP_0_0.
        self.quant_MP_0_1 = QuantIdentity(
            act_quant=self.quant_MP_0_0.act_quant,
            return_quant_tensor=True)
        # Direct Branch
        self.quant_MP_1 = QuantIdentity(
            act_quant=self.quant_MP_0_0.act_quant,
            return_quant_tensor=True)

        # Separate quantizer for the result of the residual addition.
        self.quant_add = QuantIdentity(
            act_quant=Int8ActPerTensorFloat,
            bit_width=residual_bit_width,
            restrict_scaling_type=RestrictValueType.POWER_OF_TWO,
            return_quant_tensor=True)

        self.avg_pool = TruncAvgPool2d(
            kernel_size=112,
            trunc_quant=TruncTo8bit,
            float_to_int_impl_type='FLOOR')

        # Linear 1
        self.linear_features.append(
            QuantLinear(
                in_features=hidden_channels,
                out_features=hidden_features,
                bias=False,
                weight_quant=MyWeightsQuant_PerTensor,
                weight_bit_width=weight_bit_width))
        self.linear_features.append(BatchNorm1d(hidden_features))
        self.linear_features.append(
            QuantReLU(
                act_quant=MyReLUQuant,
                bit_width=act_bit_width,
                return_quant_tensor=False))

        # Linear 2
        # The output width follows `num_classes` (previously hard-coded to 2,
        # which silently ignored the constructor argument).
        self.linear_features.append(
            QuantLinear(
                in_features=hidden_features,
                out_features=num_classes,
                bias=False,
                weight_quant=MyWeightsQuant_PerTensor,
                weight_bit_width=weight_bit_width))
        self.linear_features.append(TensorNorm())

    def forward(self, x):
        """Run the network on a batch of images.

        Args:
            x: Input batch; assumed to be scaled to ``[0, 1]`` so that
                ``2 * x - 1`` lands in the input quantizer's ``[-1, 1)``
                range — TODO confirm against the data pipeline.

        Returns:
            The output of the final ``TensorNorm`` layer (one row per batch
            element; presumably ``num_classes`` columns).
        """
        # Kept as an explicit one-element tensor so the exported graph
        # constant is the same as before.
        x = 2.0 * x - torch.tensor([1.0], device=x.device)
        for mod in self.conv_features:
            x = mod(x)

        # Skip path and conv path both start from the pooled feature map.
        x_res = self.quant_MP_1(x)
        x_conv = self.quant_MP_0_0(x)
        for mod in self.conv_branch:
            x_conv = mod(x_conv)
        x_conv = self.quant_MP_0_1(x_conv)

        x = self.quant_add(x_conv + x_res)

        x = self.avg_pool(x)

        # Flatten everything but the batch dimension for the linear layers.
        x = x.view(x.shape[0], -1)
        for mod in self.linear_features:
            x = mod(x)
        return x

In [4]:
# Build the network with its default hyperparameters and keep it on the CPU.
model_qnn = TINY_RESNET().to('cpu')

  warn('Keyword arguments are being passed but they not being used.')


In [7]:
# Shape of the dummy input passed to the exporter: batch of 1, 3 channels, 224x224.
input_shape = (1, 3, 224, 224)
# print(summary(model_qnn, input_size=input_shape))

In [8]:
# Directory that receives every model file written by this notebook.
models_folder = './step_by_step_tiny'
# Path of the QONNX file exported from the Brevitas model.
model_qnn_filename = f'{models_folder}/TINY_Resnet__QONNX.onnx'

In [9]:
# Put the model in eval mode, then export it to QONNX using a random tensor
# of shape `input_shape` as the example input. The trailing `;` suppresses
# the cell's return-value display.
model_qnn.eval();
export_qonnx(model_qnn, torch.randn(input_shape), model_qnn_filename);

# FINN Flow

## Load Model and View

In [119]:
from finn.util.visualization import showSrc, showInNetron
from qonnx.util.cleanup import cleanup as qonnx_cleanup

In [120]:
# Serve the raw QONNX export in Netron for visual inspection.
showInNetron(model_qnn_filename)

Stopping http://0.0.0.0:8083
Serving './step_by_step_tiny/TINY_Resnet__QONNX.onnx' at http://0.0.0.0:8083


## Clean

In [121]:
# Run the QONNX cleanup on the raw export and write the result to its own file.
qonnx_clean_filename = f'{models_folder}/01_clean.onnx'
qonnx_cleanup(model_qnn_filename, out_file=qonnx_clean_filename)

In [122]:
# Serve the cleaned-up model in Netron for visual inspection.
showInNetron(qonnx_clean_filename)

Stopping http://0.0.0.0:8083
Serving './step_by_step_tiny/01_clean.onnx' at http://0.0.0.0:8083


## Convert to FINN

In [123]:
from qonnx.core.modelwrapper import ModelWrapper

In [124]:
from finn.transformation.qonnx.convert_qonnx_to_finn import ConvertQONNXtoFINN
from qonnx.transformation.infer_shapes import InferShapes
from qonnx.transformation.fold_constants import FoldConstants
from qonnx.transformation.general import GiveReadableTensorNames, GiveUniqueNodeNames, RemoveStaticGraphInputs

In [125]:
# Load the cleaned QONNX graph, convert it to FINN-ONNX and tidy it up.
# The transformations run top to bottom, each on the result of the previous one.
model = (
    ModelWrapper(qonnx_clean_filename)
    .transform(ConvertQONNXtoFINN())
    .transform(InferShapes())
    .transform(FoldConstants())
    .transform(GiveUniqueNodeNames())
    .transform(GiveReadableTensorNames())
    .transform(RemoveStaticGraphInputs())
)

In [126]:
# Checkpoint the converted and tidied FINN model.
finn_tidy = f'{models_folder}/02_finn_tidy.onnx'
model.save(finn_tidy)

In [127]:
# Serve the tidied FINN model in Netron for visual inspection.
showInNetron(finn_tidy)

Stopping http://0.0.0.0:8083
Serving './step_by_step_tiny/02_finn_tidy.onnx' at http://0.0.0.0:8083


## Preprocess

In [128]:
from finn.util.pytorch import ToTensor
from qonnx.transformation.merge_onnx_models import MergeONNXModels
from qonnx.core.datatype import DataType
from qonnx.transformation.infer_datatypes import InferDataTypes

In [129]:
# Reload the tidied FINN model and read the name and shape of its first graph input.
model = ModelWrapper(finn_tidy)
global_inp_name = model.graph.input[0].name
ishape = model.get_tensor_shape(global_inp_name)
# preprocessing: torchvision's ToTensor divides uint8 inputs by 255;
# here a ToTensor module is exported as its own small ONNX graph
totensor_pyt = ToTensor()
chkpt_preproc_name = models_folder + "/prepro_node.onnx"
export_qonnx(totensor_pyt, torch.randn(ishape), chkpt_preproc_name)
# clean the exported preprocessing graph in place (same input and output file)
qonnx_cleanup(chkpt_preproc_name, out_file=chkpt_preproc_name)
pre_model = ModelWrapper(chkpt_preproc_name)
pre_model = pre_model.transform(ConvertQONNXtoFINN())

# join preprocessing and core model
model = model.transform(MergeONNXModels(pre_model))
# re-read the first graph input after the merge (presumably it is now the
# preprocessing graph's input — verify in Netron) and annotate it as UINT8,
# i.e. raw 8-bit values before the division by 255
global_inp_name = model.graph.input[0].name
model.set_tensor_datatype(global_inp_name, DataType["UINT8"])



### Tidy again

In [130]:
# Second tidy-up pass on the merged (preprocessing + core) model.
# The transformations run top to bottom, each on the result of the previous one.
model = (
    model
    .transform(InferShapes())
    .transform(FoldConstants())
    .transform(GiveUniqueNodeNames())
    .transform(GiveReadableTensorNames())
    .transform(InferDataTypes())
    .transform(RemoveStaticGraphInputs())
)

In [131]:
# Checkpoint the model with the preprocessing graph merged in.
finn_prepro = f'{models_folder}/03_finn_prepro.onnx'
model.save(finn_prepro)

In [132]:
# Serve the preprocessed model in Netron for visual inspection.
showInNetron(finn_prepro)

Stopping http://0.0.0.0:8083
Serving './step_by_step_tiny/03_finn_prepro.onnx' at http://0.0.0.0:8083


## Streamline

In [152]:
from qonnx.transformation.lower_convs_to_matmul import LowerConvsToMatMul

from qonnx.transformation.change_datalayout import ChangeDataLayoutQuantAvgPool2d
from qonnx.transformation.infer_data_layouts import InferDataLayouts
from qonnx.transformation.general import RemoveUnusedTensors

from finn.transformation.streamline import Streamline
import finn.transformation.streamline.absorb as absorb
from finn.transformation.streamline.reorder import MoveScalarLinearPastInvariants
from finn.transformation.streamline.reorder import MakeMaxPoolNHWC
from finn.transformation.streamline.reorder import MoveLinearPastEltwiseAdd

In [154]:
# Streamlining, starting again from the saved preprocessing checkpoint.
# The transformations run top to bottom, each on the result of the previous one.
model = (
    ModelWrapper(finn_prepro)
    # residual add: move linear ops past the eltwise add, then absorb into thresholds
    .transform(MoveLinearPastEltwiseAdd())
    .transform(absorb.AbsorbAddIntoMultiThreshold())
    .transform(absorb.AbsorbMulIntoMultiThreshold())
    # first streamlining round, then lowering and layout changes
    .transform(MoveScalarLinearPastInvariants())
    .transform(Streamline())
    .transform(LowerConvsToMatMul())
    .transform(MakeMaxPoolNHWC())
    .transform(ChangeDataLayoutQuantAvgPool2d())
    .transform(absorb.AbsorbTransposeIntoMultiThreshold())
    # second streamlining round and final clean-up
    .transform(Streamline())
    .transform(InferDataLayouts())
    .transform(RemoveUnusedTensors())
)



In [155]:
# Checkpoint the streamlined model.
finn_streamline = f'{models_folder}/04_finn_streamline.onnx'
model.save(finn_streamline)

In [156]:
# Serve the streamlined model in Netron for visual inspection.
showInNetron(finn_streamline)

Stopping http://0.0.0.0:8083
Serving './step_by_step_tiny/04_finn_streamline.onnx' at http://0.0.0.0:8083


# To HW Layers

In [144]:
import finn.transformation.fpgadataflow.convert_to_hw_layers as to_hw
from finn.transformation.fpgadataflow.create_dataflow_partition import (
    CreateDataflowPartition,
)
from finn.transformation.move_reshape import RemoveCNVtoFCFlatten

from qonnx.custom_op.registry import getCustomOp

In [157]:
# Convert the streamlined graph to FINN hardware layers, starting from the
# saved streamline checkpoint.
# NOTE(review): the recorded run of this cell ended with
# "AssertionError: MultiThreshold_5: Signed output requires actval < 0".
# The output does not show which step raised it — TODO investigate before
# relying on any model saved after this cell.
model = ModelWrapper(finn_streamline)

model = model.transform(to_hw.InferAddStreamsLayer())
model = model.transform(to_hw.InferQuantizedMatrixVectorActivation())

# input quantization (if any) to standalone thresholding
model = model.transform(to_hw.InferThresholdingLayer())
model = model.transform(to_hw.InferPool())
model = model.transform(to_hw.InferStreamingMaxPool())
model = model.transform(to_hw.InferConvInpGen())

# get rid of Reshape(-1, 1) operation between hw nodes 
model = model.transform(RemoveCNVtoFCFlatten())

# get rid of Transpose -> Transpose identity seq
model = model.transform(absorb.AbsorbConsecutiveTransposes())

# infer tensor data layouts
model = model.transform(InferDataLayouts())

# one more streamlining round on the converted graph
model = model.transform(Streamline())

AssertionError: MultiThreshold_5: Signed output requires actval < 0

In [149]:
# Checkpoint the model after the hardware-layer conversion.
finn_hw_layers = f'{models_folder}/05_fin_hw_layers.onnx'
model.save(finn_hw_layers)

In [150]:
# Serve the hardware-layer model in Netron for visual inspection.
showInNetron(finn_hw_layers)

Stopping http://0.0.0.0:8083
Serving './step_by_step_tiny/05_fin_hw_layers.onnx' at http://0.0.0.0:8083
