In [None]:
!pip install brevitas
!pip install -U netron

# PyTorch libraries and modules
import torch
import torch.nn as nn
from torch.optim import Adam
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import DataLoader, Dataset


import pandas as pd
import numpy as np
from sklearn.utils import resample

from torch.nn import Module, ModuleList, BatchNorm2d, MaxPool2d, BatchNorm1d, ReLU, Softmax, CrossEntropyLoss, Sequential, Dropout, Conv2d, Linear

from brevitas.nn import QuantConv2d, QuantIdentity, QuantLinear, QuantReLU
from brevitas.core.restrict_val import RestrictValueType
from tensor_norm import TensorNorm
from common import CommonWeightQuant, CommonActQuant

from qonnx.core.modelwrapper import ModelWrapper
from qonnx.custom_op.registry import getCustomOp

In [None]:
# Conv stack description: one (out_channels, apply_max_pool_after) pair per conv layer.
CNV_OUT_CH_POOL = [(21, False), (21, True), (21, False)]
# Hidden fully-connected layers as (in_features, out_features).
# 3549 = 21 * 13 * 13: with kernel 6 / padding 4 each conv grows the side by 3,
# so a 14x14 input goes 14 -> 17 -> 20 -> (pool /2) 10 -> 13.
INTERMEDIATE_FC_FEATURES = [(3549, 16), (16, 16)]
LAST_FC_IN_FEATURES = 16  # input width of the final classifier layer
LAST_FC_PER_OUT_CH_SCALING = False  # not referenced anywhere in the visible cells
POOL_SIZE = 2  # max-pool window size
KERNEL_SIZE = 6  # square conv kernel size

# Quantization bit widths (identifier spellings are kept: CNV refers to them by name).
MixPrecisionBits = 4  # activation quantizers after each conv block
ConvPrecisonBits = 4  # conv weights
LinearPrecisonBits = 4  # linear weights and the activations after hidden FC layers

class CNV(Module):
    """Quantized convolutional classifier built from Brevitas layers.

    Layout (driven by the module-level constants):
      * ``conv_features``: input ``QuantIdentity``, then for every entry of
        ``CNV_OUT_CH_POOL`` a ``QuantConv2d`` -> ``BatchNorm2d`` ->
        ``QuantIdentity`` group, optionally followed by ``MaxPool2d``.
      * ``linear_features``: for every entry of ``INTERMEDIATE_FC_FEATURES`` a
        ``QuantLinear`` -> ``BatchNorm1d`` -> ``QuantIdentity`` group, then the
        final ``QuantLinear`` classifier and a ``TensorNorm``.

    Args:
        num_classes: number of outputs of the final linear layer.
        weight_bit_width: accepted for interface compatibility but NOT used;
            weight precision comes from ``ConvPrecisonBits`` /
            ``LinearPrecisonBits``.
        act_bit_width: accepted for interface compatibility but NOT used;
            activation precision comes from ``MixPrecisionBits`` /
            ``LinearPrecisonBits``.
        in_bit_width: bit width of the input quantizer.
        in_ch: number of input channels of the first convolution.
    """

    def __init__(self, num_classes, weight_bit_width, act_bit_width, in_bit_width, in_ch):
        super(CNV, self).__init__()

        self.conv_features = ModuleList()
        self.linear_features = ModuleList()

        # Input quantizer for a Q1.7-style input in [-1, 1 - 2^-7].
        # narrow_range=True drops the most negative integer code so the
        # representable range is symmetric around zero; the scale is
        # restricted to a power of two.
        self.conv_features.append(QuantIdentity(
            act_quant=CommonActQuant,
            bit_width=in_bit_width,
            min_val=- 1.0,
            max_val=1.0 - 2.0 ** (-7),
            narrow_range=True,
            restrict_scaling_type=RestrictValueType.POWER_OF_TWO))

        for out_ch, is_pool_enabled in CNV_OUT_CH_POOL:
            # Convolution bias is enabled here.
            self.conv_features.append(QuantConv2d(kernel_size=KERNEL_SIZE, in_channels=in_ch, out_channels=out_ch,
                bias=True, padding=4, weight_quant=CommonWeightQuant, weight_bit_width=ConvPrecisonBits))
            in_ch = out_ch
            self.conv_features.append(BatchNorm2d(in_ch, eps=1e-4))
            self.conv_features.append(QuantIdentity(act_quant=CommonActQuant, bit_width=MixPrecisionBits))
            if is_pool_enabled:
                # Use the declared constant instead of a hard-coded window size.
                self.conv_features.append(MaxPool2d(kernel_size=POOL_SIZE))

        for in_features, out_features in INTERMEDIATE_FC_FEATURES:
            self.linear_features.append(QuantLinear(in_features=in_features, out_features=out_features, bias=True,
                weight_quant=CommonWeightQuant, weight_bit_width=LinearPrecisonBits))
            self.linear_features.append(BatchNorm1d(out_features, eps=1e-4))
            self.linear_features.append(QuantIdentity(act_quant=CommonActQuant, bit_width=LinearPrecisonBits))

        # Classifier head (no bias) followed by the project's TensorNorm layer.
        self.linear_features.append(QuantLinear(in_features=LAST_FC_IN_FEATURES, out_features=num_classes, bias=False,
            weight_quant=CommonWeightQuant, weight_bit_width=LinearPrecisonBits))
        self.linear_features.append(TensorNorm())

        # Re-initialise every quantized conv/linear weight uniformly in [-1, 1].
        for m in self.modules():
            if isinstance(m, (QuantConv2d, QuantLinear)):
                torch.nn.init.uniform_(m.weight.data, -1, 1)

    def clip_weights(self, min_val, max_val):
        """Clamp, in place, the weights of all quantized conv and linear layers.

        Args:
            min_val: lower clamp bound.
            max_val: upper clamp bound.
        """
        for features in (self.conv_features, self.linear_features):
            for mod in features:
                if isinstance(mod, (QuantConv2d, QuantLinear)):
                    mod.weight.data.clamp_(min_val, max_val)

    def forward(self, x):
        """Run the network on a batch.

        The input is first mapped with ``2 * x - 1`` (i.e. [0, 1] -> [-1, 1];
        that the data is in [0, 1] is an assumption to confirm), passed through
        the conv stack, flattened per sample and passed through the linear
        stack.

        Args:
            x: input batch; the first dimension is the batch dimension.

        Returns:
            The output of the last module in ``linear_features``.
        """
        x = 2.0 * x - torch.tensor([1.0], device=x.device)
        for mod in self.conv_features:
            x = mod(x)
        # Flatten everything except the batch dimension.
        x = x.view(x.shape[0], -1)
        for mod in self.linear_features:
            x = mod(x)
        return x

# Prefer the GPU when one is available; the model is moved there later via model.to(device).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): CNV.__init__ accepts weight_bit_width / act_bit_width but never
# reads them, so the 1-bit values passed here have no effect; the effective
# precisions are the module-level *PrecisionBits constants.
model = CNV(num_classes=5, weight_bit_width=1, act_bit_width=1, in_bit_width=8, in_ch=1)

In [None]:
# Load the pre-split train / validation / test tensors from the working directory.
# The file names suggest N samples of 14x14 float32 inputs and int64 labels
# (250000 / 44120 / 45270) -- TODO confirm against the files.
# NOTE(review): torch.load unpickles its input; only load files from a trusted source.
xtrain_reshape = torch.load("xtrain25000014x14float32.pth")
ytrain_tensor = torch.load("ytrain250000int64.pth")
xval_reshape = torch.load("xval4412014x14float32.pth")
yval_tensor = torch.load("yval44120int64.pth")
xtest_reshape = torch.load("xtest4527014x14float32.pth")
ytest_tensor = torch.load("ytest45270int64.pth")

In [None]:
class Data(Dataset):
    """Map-style dataset that pairs samples with their labels.

    A singleton axis is inserted at position 1 of ``X`` (the channel axis the
    convolutional model expects), so item ``i`` is ``(X[i] with a leading
    axis of size 1, y[i])``.
    """

    def __init__(self, X, y):
        # (N, ...) -> (N, 1, ...)
        self.X = torch.unsqueeze(X, dim=1)
        self.y = y
        # Number of samples, cached once.
        self.len = len(self.X)

    def __len__(self):
        return self.len

    def __getitem__(self, index):
        sample = self.X[index]
        label = self.y[index]
        return sample, label

In [None]:
batch_size = 100

# Only the training set is shuffled. Validation and test metrics are
# order-independent, so those loaders iterate in a fixed order, which keeps
# evaluation deterministic and avoids a needless permutation every epoch.
train_data = Data(xtrain_reshape, ytrain_tensor)
train_dataloader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)

val_data = Data(xval_reshape, yval_tensor)
val_dataloader = DataLoader(dataset=val_data, batch_size=batch_size, shuffle=False)

test_data = Data(xtest_reshape, ytest_tensor)
test_dataloader = DataLoader(dataset=test_data, batch_size=batch_size, shuffle=False)

In [None]:
# Move the model to the selected device *before* building the optimizer so the
# optimizer holds references to the parameters on that device.
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = CrossEntropyLoss()
local_ep = 10  # number of training epochs run by the loop below

In [None]:
from tqdm import tqdm

epoch_loss = []  # mean training loss of each epoch
batch_loss = []  # training loss of every batch, accumulated across all epochs

for epoch in range(local_ep):  # `epoch`, not `iter`: avoid shadowing the builtin
    # ---- training pass ----
    model.train()
    criterion.train()

    train_correct = 0
    train_total = 0
    train_batch_loss = []  # losses of the current epoch only

    total_batches = len(train_dataloader)

    # Context manager guarantees the bar is closed even if a batch raises.
    with tqdm(total=total_batches, desc="Processing Local Epoch(s)", unit="batch", position=0, leave=True) as progress_bar:
        for batch_idx, (xtrain, ytrain) in enumerate(train_dataloader):
            xtrain, ytrain = xtrain.to(device), ytrain.to(device)
            model_preds = model(xtrain)
            loss = criterion(model_preds, ytrain)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Keep the latent weights inside [-1, 1] after every update.
            model.clip_weights(-1, 1)

            loss_value = loss.item()
            train_batch_loss.append(loss_value)
            batch_loss.append(loss_value)

            # Accuracy uses the predictions made before this batch's update.
            pred_labels = model_preds.argmax(dim=1)
            train_correct += (pred_labels == ytrain).sum().item()
            train_total += ytrain.size(0)

            progress_bar.update(1)
            progress_bar.set_postfix(batch=batch_idx + 1, refresh=True)

    # Average over THIS epoch's batches only. (Averaging `batch_loss`, which is
    # never reset, would yield a running mean over all epochs so far.)
    avg_train_loss = sum(train_batch_loss) / len(train_batch_loss) if train_batch_loss else float("nan")
    epoch_loss.append(avg_train_loss)

    # ---- validation pass ----
    model.eval()
    criterion.eval()

    val_correct = 0
    val_total = 0
    val_batch_loss = []

    with torch.no_grad():
        for batch_idx, (xval, yval) in enumerate(val_dataloader):
            xval, yval = xval.to(device), yval.to(device)
            val_model_preds = model(xval)
            val_loss = criterion(val_model_preds, yval)
            val_batch_loss.append(val_loss.item())
            val_pred_labels = val_model_preds.argmax(dim=1)
            val_correct += (val_pred_labels == yval).sum().item()
            val_total += yval.size(0)

    # ---- per-epoch report ----
    train_accuracy = train_correct / train_total if train_total > 0 else 0.0
    avg_val_loss = sum(val_batch_loss) / len(val_batch_loss) if val_batch_loss else float("nan")
    val_accuracy = val_correct / val_total if val_total > 0 else 0.0
    print('Local Epoch: ', epoch + 1,
          'Training Loss: ', avg_train_loss, 'Training Accuracy:{:.4f}%'.format(100*train_accuracy))
    print('Validation Loss: ', avg_val_loss, 'Validation Accuracy:{:.4f}%'.format(100*val_accuracy))
    # Leave the model in training mode, as the original loop did.
    model.train()

In [None]:
# Output directory prefix for checkpoints, ONNX files and build results.
# The trailing slash matters: paths below are built by plain string concatenation.
# NOTE(review): nothing in this notebook creates the directory -- presumably it
# already exists (it is also expected to hold the folding config JSON); confirm.
build_dir = 'Accel7/'

In [None]:
# Persist the trained weights and the optimizer state under build_dir.
model_state_path = build_dir + 'Model_State_dict.pth'
optimizer_state_path = build_dir + "Optimizer_state_dict.pth"
torch.save(model.state_dict(), model_state_path)
torch.save(optimizer.state_dict(), optimizer_state_path)

# Report the mean of the recorded per-epoch losses.
mean_epoch_loss = sum(epoch_loss) / len(epoch_loss)
print("Loss", mean_epoch_loss)

In [None]:
# Restore the weights saved above into the existing CNV instance. This lets the
# notebook resume from here, provided `model` is still the PyTorch CNV object
# (later cells rebind `model` to an ONNX ModelWrapper).
model.load_state_dict(torch.load(build_dir +'Model_State_dict.pth'))

In [None]:
def test_inference(model, test_dataloader):
    model.eval()
    Tall_true_labels = []
    Tall_predicted_labels = []

    loss, total, correct = 0.0, 0.0, 0.0
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    criterion = nn.CrossEntropyLoss().to(device)

    for batch_idx, (xtest, ytest) in enumerate(test_dataloader):

        xtest, ytest = xtest.to(device), ytest.to(device)

        outputs = model(xtest)
        batch_loss = criterion(outputs, ytest)
        loss += batch_loss.item()

        pred = outputs.data.argmax(1, keepdim=True)
        correct += pred.eq(ytest.data.view_as(pred)).sum()
        total += len(ytest)
    accuracy = 100. * correct.float() / total                    
    loss = loss/total

    return accuracy, loss

# Evaluate on the held-out test set; the cell displays the returned (accuracy, loss) pair.
test_inference(model, test_dataloader)

In [None]:
import brevitas.onnx as bo

# The exporter is given only an input *shape*, so it has to build its own dummy
# input for tracing. Move the model to the CPU first so the model and that
# input share a device even when training ran on CUDA.
model.cpu()
bo.export_finn_onnx(model, (1, 1, 14, 14), build_dir +"export.onnx");

In [None]:
# Visual check of the raw exported graph.
from finn.util.visualization import showInNetron
showInNetron(build_dir + "export.onnx")

In [None]:
import onnx
from finn.util.test import get_test_model_trained
import brevitas.onnx as bo
from qonnx.core.modelwrapper import ModelWrapper
from qonnx.transformation.infer_shapes import InferShapes
from qonnx.transformation.infer_datatypes import InferDataTypes
from qonnx.transformation.fold_constants import FoldConstants
from qonnx.transformation.general import GiveReadableTensorNames, GiveUniqueNodeNames, RemoveStaticGraphInputs
from qonnx.transformation.infer_datatypes import InferDataTypes

# Load the exported graph, infer tensor shapes, fold constant sub-graphs and
# write the result out as tidy.onnx.
# Note: from here on `model` is a QONNX ModelWrapper, no longer the PyTorch CNV.
model = (
    ModelWrapper(build_dir + "export.onnx")
    .transform(InferShapes())
    .transform(FoldConstants())
)
model.save(build_dir + "tidy.onnx")

In [None]:
# Visual check of the graph after shape inference and constant folding.
from finn.util.visualization import showInNetron
showInNetron(build_dir + "tidy.onnx")

In [None]:
from finn.util.pytorch import ToTensor
from qonnx.transformation.merge_onnx_models import MergeONNXModels
from qonnx.core.datatype import DataType
from qonnx.transformation.insert_topk import InsertTopK
from qonnx.transformation.infer_datatypes import InferDataTypes
import brevitas.onnx as bo
from qonnx.core.modelwrapper import ModelWrapper

# Reload the tidied core model and look up its (single) graph input.
model = ModelWrapper(build_dir+"tidy.onnx")
global_inp_name = model.graph.input[0].name
ishape = model.get_tensor_shape(global_inp_name)

# Export the preprocessing module as a stand-alone ONNX graph with the same
# input shape. ToTensor here is the one imported from finn.util.pytorch; per
# the original note it divides uint8 inputs by 255.
chkpt_preproc_name = build_dir+"preproc.onnx"
totensor_pyt = ToTensor()
bo.export_finn_onnx(totensor_pyt, ishape, chkpt_preproc_name)

# Prepend the preprocessing graph to the core model.
pre_model = ModelWrapper(chkpt_preproc_name)
model = model.transform(MergeONNXModels(pre_model))

# The merged graph has a new input tensor: annotate it as UINT8.
global_inp_name = model.graph.input[0].name
model.set_tensor_datatype(global_inp_name, DataType["UINT8"])


from qonnx.transformation.infer_shapes import InferShapes
from qonnx.transformation.fold_constants import FoldConstants

# Post-processing: append a TopK(k=1) node, then re-run shape inference,
# constant folding and datatype inference before saving the final graph.
model = (
    model
    .transform(InsertTopK(k=1))
    .transform(InferShapes())
    .transform(FoldConstants())
    .transform(InferDataTypes())
)

model.save(build_dir + "preproc2.onnx")

In [None]:
# Visual check of the final graph (preprocessing + core model + TopK).
from finn.util.visualization import showInNetron
showInNetron(build_dir + "preproc2.onnx")

In [None]:
import finn.builder.build_dataflow as build
import finn.builder.build_dataflow_config as build_cfg
import os
import shutil

# Input graph for the FINN builder: the model saved by the post-processing cell.
model_file = build_dir+"preproc2.onnx"

# Directory that receives all build outputs.
rtlsim_output_dir = build_dir+"Buildv2"

# Start from a clean output directory: remove results of any previous run.
if os.path.exists(rtlsim_output_dir):
    shutil.rmtree(rtlsim_output_dir)
    print("Previous run results deleted!")

# Build configuration. Parameter meanings follow the FINN builder's
# DataflowBuildConfig -- NOTE(review): confirm against the FINN version in use.
cfg_stitched_ip = build.DataflowBuildConfig(
    output_dir          = rtlsim_output_dir,
    mvau_wwidth_max     = 10,
    target_fps          = 1000000,
    synth_clk_period_ns = 10.0,  # 10 ns period = 100 MHz
    fpga_part           = "xczu9eg-ffvb1156-2-e",
    board               = "ZCU102",
    shell_flow_type     = build_cfg.ShellFlowType.VIVADO_ZYNQ,

    # Folding configuration read from a JSON file under build_dir; the file must already exist.
    folding_config_file = build_dir+"final_hw_config_Accel7.json",

    auto_fifo_depths    = False,
    
    # Optional explicit step list, currently disabled; with no `steps` argument
    # the config's default is used.
#     steps=["step_apply_folding_config",
#            "step_generate_estimate_reports",
#            "step_hls_codegen",
#            "step_hls_ipgen",
#            "step_set_fifo_depths",
#            "step_create_stitched_ip",
#            "step_measure_rtlsim_performance",
#            "step_out_of_context_synthesis",
#            "step_synthesize_bitfile",
#            "step_make_pynq_driver",
#           ],
    # Artifacts requested from the build.
    generate_outputs=[
        build_cfg.DataflowOutputType.STITCHED_IP,
        build_cfg.DataflowOutputType.RTLSIM_PERFORMANCE,
        build_cfg.DataflowOutputType.OOC_SYNTH,
        build_cfg.DataflowOutputType.BITFILE,
        build_cfg.DataflowOutputType.PYNQ_DRIVER,
        build_cfg.DataflowOutputType.DEPLOYMENT_PACKAGE]
)
 

#     auto_fifo_depths    = False,

In [None]:
%%time
# Launch the FINN dataflow build with the configuration defined above; the cell
# magic on the first line reports how long the build took.
build.build_dataflow_cfg(model_file, cfg_stitched_ip)