# Train a Quantized MLP on UNSW-NB15 with Brevitas

<font color="red">**Live FINN tutorial:** We recommend clicking **Cell -> Run All** when you start reading this notebook for "latency hiding".</font>

In this notebook, we will show how to create, train and export a Multi-Layer Perceptron (MLP) with quantized weights and activations using [Brevitas](https://github.com/Xilinx/brevitas).
Specifically, the task at hand will be to label network packets as normal or suspicious (e.g. originating from an attacker, virus, malware or otherwise) by training on a quantized variant of the UNSW-NB15 dataset. 

**You won't need a GPU to train the neural net.** This MLP is small enough to train on a modern x86 CPU. Alternatively, we provide pre-trained parameters for the MLP if you want to skip the training entirely.


## A quick introduction to the task and the dataset

*The task:* The goal of [*network intrusion detection*](https://ieeexplore.ieee.org/abstract/document/283931) is to identify, preferably in real time, unauthorized use, misuse, and abuse of computer systems by both system insiders and external penetrators. This may be achieved by a mix of techniques, and machine-learning (ML) based techniques are increasing in popularity. 

*The dataset:* Several datasets are available for use in ML-based methods for intrusion detection.
**UNSW-NB15** is one such dataset, created by the Australian Centre for Cyber Security (ACCS) to provide a comprehensive network-based dataset that reflects modern network traffic scenarios. You can find more details about the dataset on [its homepage](https://www.unsw.adfa.edu.au/unsw-canberra-cyber/cybersecurity/ADFA-NB15-Datasets/).

*Performance considerations:* FPGAs are commonly used for implementing high-performance packet processing systems that still provide a degree of programmability. To avoid introducing bottlenecks on the network, the DNN implementation must be capable of detecting malicious packets at line rate, which can be millions of packets per second and is expected to increase further as next-generation networking solutions provide increased throughput. This is a good reason to consider FPGA acceleration for this particular use case.
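
To put "line rate" into numbers, here is a rough back-of-the-envelope sketch. The 10 Gbit/s link speed and the 200 MHz accelerator clock are assumptions chosen purely for illustration; the frame size and per-frame overhead are the standard Ethernet minimums.

```python
# Back-of-the-envelope line-rate estimate (illustrative assumptions only)
LINK_RATE_BPS = 10e9   # assumed 10 Gbit/s Ethernet link
MIN_FRAME_BYTES = 64   # minimum Ethernet frame size
OVERHEAD_BYTES = 20    # preamble + start delimiter (8) and inter-frame gap (12)
CLOCK_HZ = 200e6       # hypothetical accelerator clock frequency

# Worst case: the link is saturated with minimum-size frames
bits_per_packet = (MIN_FRAME_BYTES + OVERHEAD_BYTES) * 8
packets_per_second = LINK_RATE_BPS / bits_per_packet

# Clock cycles available to classify each packet at that rate
cycles_per_packet = CLOCK_HZ / packets_per_second

print(f"{packets_per_second / 1e6:.2f} Mpps")
print(f"{cycles_per_packet:.1f} cycles per packet")
```

Under these assumptions the classifier has to sustain roughly 14.88 million packets per second, which leaves only about 13 clock cycles per packet.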

## Outline
-------------

* [Load the UNSW_NB15 Dataset](#load_dataset) 
* [Define the Quantized MLP Model](#define_quantized_mlp)
* [Define Train and Test Methods](#train_test)
    * [(Option 1) Train the Model from Scratch](#train_scratch)
    * [(Option 2) Load Pre-Trained Parameters](#load_pretrained)
* [Network Surgery Before Export](#network_surgery)
* [Export to FINN-ONNX](#export_finn_onnx)

In [1]:
import onnx
import torch

**This is important -- always import onnx before torch**. This is a workaround for a [known bug](https://github.com/onnx/onnx/issues/2394).

## Load the UNSW_NB15 Dataset <a id='load_dataset'></a>

### Dataset Quantization <a id='dataset_qnt'></a>

The goal of this notebook is to train a Quantized Neural Network (QNN) to be later deployed as an FPGA accelerator generated by the FINN compiler. Although we can choose a variety of different precisions for the input, [Murovic and Trost](https://ev.fe.uni-lj.si/1-2-2019/Murovic.pdf) have previously shown we can actually binarize the inputs and still get good (90%+) accuracy.

We will create a binarized representation for the dataset by following the procedure defined by Murovic and Trost, which we repeat briefly here:

* The original features come in different formats: integers, floating-point numbers and strings.
* Integers, which for example represent a packet lifetime, are binarized with as many bits as are needed to represent the maximum value.
* Features formatted as strings (e.g. protocols) are binarized by counting the number of distinct strings for each feature and coding them with the appropriate number of bits.
* Floating-point numbers are reformatted into a fixed-point representation.
* In the end, each sample is transformed into a 593-bit wide binary vector.
* All vectors are labeled as bad (0) or normal (1).
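
The three encodings above can be sketched in a few lines of plain Python. This is only an illustration of the idea, not the procedure used by Murovic and Trost or by `dataloader_quantized.py`: the helper names, the example features and the fixed-point format (4 integer bits, 4 fractional bits) are all hypothetical.

```python
def int_to_bits(value, max_value):
    """Encode a non-negative integer with just enough bits to hold max_value."""
    width = max(1, max_value.bit_length())
    return [int(b) for b in format(value, f"0{width}b")]


def string_to_bits(value, vocabulary):
    """Encode a categorical string as the binary index of its sorted position."""
    categories = sorted(set(vocabulary))
    return int_to_bits(categories.index(value), len(categories) - 1)


def float_to_fixed_bits(value, int_bits, frac_bits):
    """Encode a non-negative float as unsigned fixed point (truncating, saturating)."""
    scaled = int(value * (1 << frac_bits))
    max_scaled = (1 << (int_bits + frac_bits)) - 1
    return int_to_bits(min(scaled, max_scaled), max_scaled)


# Hypothetical features: a lifetime value, a protocol name and a duration
ttl_bits = int_to_bits(62, max_value=255)
proto_bits = string_to_bits("udp", ["tcp", "udp", "icmp"])
dur_bits = float_to_fixed_bits(1.75, int_bits=4, frac_bits=4)

# Concatenating the per-feature bits gives one binary input vector
sample = ttl_bits + proto_bits + dur_bits
print(len(sample), sample)
```

Concatenating the per-feature encodings in the same way over all features of the real dataset is what yields the 593-bit vectors.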

Following Murovic and Trost's open-source implementation, provided as a MATLAB script [here](https://github.com/TadejMurovic/BNN_Deployment/blob/master/cybersecurity_dataset_unswb15.m), we've created a [Python version](dataloader_quantized.py).

<font color="red">**Live FINN tutorial:** Downloading the original dataset and quantizing it can take some time, so we provide a download link to the pre-quantized version for your convenience. </font>

We can extract the binarized numpy arrays from the .npz archive and wrap them as a PyTorch `TensorDataset` as follows:
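
A minimal sketch of this step is shown below. It assumes that the archive stores one array per partition (keys `train` and `test`) and that each row holds the feature bits followed by the label in the last column; the helper name `load_partition`, the file name and the synthetic stand-in data are hypothetical.

```python
import os
import tempfile

import numpy as np
import torch
from torch.utils.data import TensorDataset


def load_partition(npz_path, partition):
    """Wrap one partition of a binarized .npz archive as a TensorDataset.

    Assumes each row is [feature bits..., label]; this layout is an
    assumption of this sketch.
    """
    with np.load(npz_path) as archive:
        data = archive[partition].astype(np.float32)
    tensor = torch.from_numpy(data)
    # Split features (all but the last column) from labels (last column)
    return TensorDataset(tensor[:, :-1], tensor[:, -1])


# Synthetic stand-in for the real archive: 8 samples of 593 bits + 1 label
rng = np.random.default_rng(0)
fake = rng.integers(0, 2, size=(8, 594)).astype(np.uint8)
npz_path = os.path.join(tempfile.mkdtemp(), "example_binarized.npz")
np.savez(npz_path, train=fake, test=fake[:4])

train_set = load_partition(npz_path, "train")
features, label = train_set[0]
print(len(train_set), features.shape, label.item())
```

The resulting dataset can then be passed to a `torch.utils.data.DataLoader` to obtain batches for training and testing.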

## Define the Quantized MLP Model <a id='define_quantized_mlp'></a>

We'll now define an MLP model that will be trained to perform inference with quantized weights and activations.
For this, we'll use the quantization-aware training (QAT) capabilities offered by [Brevitas](https://github.com/Xilinx/brevitas).

Our MLP will have four fully-connected (FC) layers in total: three hidden layers with 64 neurons, and a final output layer with a single output, all using 2-bit weights. We'll use 2-bit quantized ReLU activation functions, and apply batch normalization between each FC layer and its activation.

In case you'd like to experiment with different quantization settings or topology parameters, we'll define all these topology settings as variables.
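
As a rough sketch, the topology described above could be captured in variables like the following (the variable names are illustrative). The same variables also give a quick estimate of the weight storage, ignoring biases and batch normalization parameters.

```python
# Topology and quantization settings for the MLP described in the text
input_size = 593
hidden_sizes = [64, 64, 64]
num_outputs = 1
weight_bit_width = 2
act_bit_width = 2

# Number of weights in each fully-connected layer: inputs x outputs
layer_sizes = [input_size] + hidden_sizes + [num_outputs]
weights_per_layer = [
    n_in * n_out for n_in, n_out in zip(layer_sizes, layer_sizes[1:])
]

total_weights = sum(weights_per_layer)
weight_bytes = total_weights * weight_bit_width / 8

print(weights_per_layer)
print(total_weights, "weights,", weight_bytes, "bytes")
```

With these settings the four layers hold 46208 weights in total, or about 11.5 kB at 2 bits per weight, and the first layer accounts for most of them.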

Now we can define our MLP using the layer primitives provided by Brevitas:

In [2]:
from brevitas.nn import QuantLinear, QuantReLU
import torch.nn as nn

# Set the seed for reproducibility
torch.manual_seed(0)

# MLP with 4-bit weights and activations: 40 -> 256 -> 128 -> 64 -> 32 -> 4
model = nn.Sequential(
    QuantLinear(40, 256, bias=True, weight_bit_width=4),
    nn.BatchNorm1d(256),
    # nn.Dropout(0.2),
    QuantReLU(bit_width=4),
    QuantLinear(256, 128, bias=True, weight_bit_width=4),
    nn.BatchNorm1d(128),
    # nn.Dropout(0.4),
    QuantReLU(bit_width=4),
    QuantLinear(128, 64, bias=True, weight_bit_width=4),
    nn.BatchNorm1d(64),
    # nn.Dropout(0.2),
    QuantReLU(bit_width=4),
    QuantLinear(64, 32, bias=True, weight_bit_width=4),
    nn.BatchNorm1d(32),
    # nn.Dropout(0.5),
    QuantReLU(bit_width=4),
    QuantLinear(32, 4, bias=True, weight_bit_width=4),
    nn.Softmax(dim=1),
)

model = model.float()
# path = './modelModel_95.pt'
# model.load_state_dict(torch.load(path, map_location='cpu'))

In [3]:
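import torch

# Load the pre-trained parameters from the checkpoint file
trained_state_dict = torch.load("state_dict.pth")["models_state_dict"][0]

# Note: strict=False only tolerates missing or unexpected keys; parameters
# whose shapes differ from the checkpoint still raise a RuntimeError
model.load_state_dict(trained_state_dict, strict=False)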

RuntimeError: Error(s) in loading state_dict for Sequential:
	size mismatch for 0.weight: copying a param with shape torch.Size([64, 593]) from checkpoint, the shape in current model is torch.Size([256, 40]).
	size mismatch for 0.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for 1.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for 1.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for 1.running_mean: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for 1.running_var: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for 4.weight: copying a param with shape torch.Size([64, 64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for 4.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for 9.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([32, 64]).
	size mismatch for 9.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for 12.weight: copying a param with shape torch.Size([1, 64]) from checkpoint, the shape in current model is torch.Size([3, 32]).
	size mismatch for 12.bias: copying a param with shape torch.Size([1]) from checkpoint, the shape in current model is torch.Size([3]).

In [18]:
test(model, test_quantized_loader)

0.9188772287810328

**Why do these parameters give better accuracy vs training from scratch?** Even with the topology and quantization fixed, achieving good accuracy on a given dataset requires [*hyperparameter tuning*](https://towardsdatascience.com/hyperparameters-optimization-526348bb8e2d) and potentially running training for a long time. The "training from scratch" example above is only intended as a quick example, whereas the pretrained parameters are obtained from a longer training run using the [determined.ai](https://determined.ai/) platform for hyperparameter tuning.

## Network Surgery Before Export <a id="network_surgery"></a>

Sometimes, it's desirable to make some changes to our trained network prior to export (this is known in general as "network surgery"). This depends on the model and is not generally necessary, but in this case we want to make a couple of changes to get better results with FINN.

Let's start by padding the input. Our input vectors are 593-bit, which will make folding (parallelization) for the first layer a bit tricky since 593 is a prime number. So we'll pad the weight matrix of the first layer with seven 0-valued columns to work with an input size of 600 instead. When using the modified network we'll similarly provide inputs padded to 600 bits.
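
To see why a prime input size is awkward, recall that a folding factor has to divide the input size evenly. The short sketch below lists the candidate factors for both sizes; `divisors` is a hypothetical helper written for this illustration.

```python
def divisors(n):
    """Return all positive integers that divide n evenly."""
    return [d for d in range(1, n + 1) if n % d == 0]


# 593 is prime, so the only options are fully serial or fully parallel
print(divisors(593))

# 600 offers many intermediate degrees of parallelism
print(divisors(600))
```

With 593 inputs the only choices are 1 or 593 inputs processed in parallel, whereas 600 inputs allow 24 different factors such as 8, 20 or 50.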

In [19]:
from copy import deepcopy

modified_model = deepcopy(model)

W_orig = modified_model[0].weight.data.detach().numpy()
W_orig.shape

(64, 593)

In [20]:
import numpy as np

# pad the second (593-sized) dimension with 7 zeroes at the end
W_new = np.pad(W_orig, [(0,0), (0,7)])
W_new.shape

(64, 600)

In [21]:
modified_model[0].weight.data = torch.from_numpy(W_new)
modified_model[0].weight.shape

torch.Size([64, 600])
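
Padding with zero-valued columns does not change what the layer computes, since the extra columns contribute nothing to the matrix-vector product. A small NumPy sketch with randomly generated stand-in weights and inputs (not the trained parameters) illustrates this:

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.integers(-2, 2, size=(64, 593))  # stand-in integer weights
x = rng.integers(0, 2, size=593)         # one binarized input vector

W_padded = np.pad(W, [(0, 0), (0, 7)])   # append seven zero columns
x_padded = np.pad(x, (0, 7))             # append seven zero inputs

# The padded product matches the original one exactly
same = np.array_equal(W @ x, W_padded @ x_padded)
print(W_padded.shape, same)
```

Because the added weight columns are zero, the result would stay the same even if the padded input positions held non-zero values.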

Next, we'll modify the expected input/output ranges. In FINN, we prefer to work with bipolar {-1, +1} instead of binary {0, 1} values. To achieve this, we'll create a "wrapper" model that handles the pre/postprocessing as follows:

* on the input side, we'll pre-process by (x + 1) / 2 in order to map incoming {-1, +1} inputs to {0, 1} ones which the trained network is used to. Since we're just multiplying/adding a scalar, these operations can be [*streamlined*](https://finn.readthedocs.io/en/latest/nw_prep.html#streamlining-transformations) by FINN and implemented with no extra cost.

* on the output side, we'll add a binary quantizer which maps everything below 0 to -1 and everything above 0 to +1. This is essentially the same behavior as the sigmoid we used earlier, except the outputs are bipolar instead of binary.
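
Both mappings can be checked with a few lines of plain Python. The sketch below is independent of Brevitas; the helper names are hypothetical, and mapping an output of exactly 0 to +1 is an assumption made here for definiteness.

```python
import math


def to_binary(x_bipolar):
    """Input preprocessing: map {-1, +1} to {0, 1}."""
    return (x_bipolar + 1) / 2


def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))


def sign_quantize(v):
    """Output quantizer: negative values map to -1, the rest to +1."""
    return 1.0 if v >= 0 else -1.0  # tie at 0 goes to +1 (assumption)


print([to_binary(x) for x in (-1, 1)])

# Thresholding the sigmoid at 0.5 gives a binary label; the sign quantizer
# gives the same decision expressed as a bipolar value
results = []
for v in [-3.0, -0.1, 0.2, 4.0]:
    binary_label = 1.0 if sigmoid(v) > 0.5 else 0.0
    results.append((binary_label, sign_quantize(v)))
print(results)
```

For every non-zero network output, the bipolar result equals `2 * binary_label - 1`, which is why the accuracy is unaffected by the change of output encoding.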

In [22]:
from brevitas.core.quant import QuantType
from brevitas.nn import QuantIdentity


class CybSecMLPForExport(nn.Module):
    def __init__(self, my_pretrained_model):
        super(CybSecMLPForExport, self).__init__()
        self.pretrained = my_pretrained_model
        self.qnt_output = QuantIdentity(quant_type=QuantType.BINARY, bit_width=1, min_val=-1.0, max_val=1.0)
    
    def forward(self, x):
        # assume x contains bipolar {-1,1} elems
        # shift from {-1,1} -> {0,1} since that is the
        # input range for the trained network
        x = (x + torch.tensor([1.0])) / 2.0  
        out_original = self.pretrained(x)
        out_final = self.qnt_output(out_original)   # output as {-1,1}     
        return out_final

model_for_export = CybSecMLPForExport(modified_model)

In [23]:
from sklearn.metrics import accuracy_score


def test_padded_bipolar(model, test_loader):
    # ensure model is in eval mode
    model.eval() 
    y_true = []
    y_pred = []
   
    with torch.no_grad():
        for data in test_loader:
            inputs, target = data
            # pad inputs to 600 elements
            input_padded = np.pad(inputs, [(0,0), (0,7)])
            # convert inputs to {-1,+1}
            input_scaled = 2*input_padded - 1
            # run the model
            output = model(torch.from_numpy(input_scaled).float())
            y_pred.extend(list(output.flatten()))
            # make targets bipolar {-1,+1}
            expected = 2*target.float() - 1
            expected = expected.detach().numpy()
            y_true.extend(list(expected.flatten()))
        
    return accuracy_score(y_true, y_pred)

In [24]:
test_padded_bipolar(model_for_export, test_quantized_loader)

0.9188772287810328

## Export to FINN-ONNX <a id="export_finn_onnx" ></a>


[ONNX](https://onnx.ai/) is an open format built to represent machine learning models, and the FINN compiler expects an ONNX model as input. We'll now export our network into ONNX to be imported and used in FINN for the next notebooks. Note that the particular ONNX representation used for FINN differs from standard ONNX; you can read more about this [here](https://finn.readthedocs.io/en/latest/internals.html#intermediate-representation-finn-onnx).

You can see below how we export a trained network in Brevitas into a FINN-compatible ONNX representation. Note how we create a `QuantTensor` instance with dummy data to tell Brevitas what our inputs look like, which will be used to set the input quantization annotation on the exported model.

In [4]:
import brevitas.onnx as bo
from brevitas.quant_tensor import QuantTensor
import numpy as np

ready_model_filename = "fpl-4-bit-200Mhz.onnx"
input_shape = (1, 40)
# create a QuantTensor instance to annotate the input quantization during export
# (here: 8-bit signed with scale 1.0; the values themselves are only dummy data)
input_a = np.random.randint(0, 255, size=input_shape).astype(np.float32)
#input_a = 2 * input_a - 1
scale = 1.0
input_t = torch.from_numpy(input_a * scale)
input_qt = QuantTensor(
    input_t, scale=torch.tensor(scale), bit_width=torch.tensor(8), signed=True
)

#bo.export_finn_onnx(model_for_export, export_path=ready_model_filename, input_t=input_qt)
bo.export_finn_onnx(model, export_path=ready_model_filename, input_t=input_qt)

print("Model saved to %s" % ready_model_filename)

Model saved to fpl-4-bit-200Mhz.onnx


## View the Exported ONNX in Netron

Let's examine the exported ONNX model with [Netron](https://github.com/lutzroeder/netron), which is a visualizer for neural networks and allows interactive investigation of network properties. For example, you can click on the individual nodes and view the properties. Particular things of note:

* The input tensor "0" is annotated with `quantization: finn_datatype: BIPOLAR`
* The input preprocessing (x + 1) / 2 is exported as part of the network (initial `Add` and `Div` layers)
* Brevitas `QuantLinear` layers are exported to ONNX as `MatMul`. We've exported the padded version; shape of the first MatMul node's weight parameter is 600x64
* The weight parameters (second inputs) for MatMul nodes are annotated with `quantization: finn_datatype: INT2`
* The quantized activations are exported as `MultiThreshold` nodes with `domain=finn.custom_op.general`
* There's a final `MultiThreshold` node with threshold=0 to produce the final bipolar output (this is the `qnt_output` from `CybSecMLPForExport`)

In [5]:
from finn.util.visualization import showInNetron

showInNetron(ready_model_filename)

Serving 'fpl-4-bit-200Mhz.onnx' at http://0.0.0.0:8081


## That's it! <a id="thats_it" ></a>
Congratulations! You have created, trained and tested a quantized MLP that is ready to be loaded into FINN. You can now proceed to the next notebook.