# What is model quantization
In most deep learning frameworks, 32-bit floating point is the default data type.
Model quantization means representing these real numbers with fewer bits, such as 8-bit integers.
In this notebook, we focus only on int8 quantization.

# Why we need model quantization
+ In moving from 32 bits to 8 bits, we can reduce the size of the model's weights to about a quarter.  
  Disk and memory usage will also be reduced.
+ Inference with a quantized model is faster than with the float32 model. 
+ On some embedded or IoT devices, integer arithmetic is the only option.

# How to convert NNabla's nnp model to quantized tflite model
In this section, you will learn how to convert NNabla's nnp model to an INT8-quantized tflite model.

* This notebook assumes that CUDA and cuDNN are already installed. If they are not, please install versions appropriate for your system.  
  See NVIDIA's installation guide: https://docs.nvidia.com/deeplearning/cudnn/install-guide/index.html

* Install the nnabla, nnabla-ext-cuda and nnabla-converter packages

In [None]:
!pip install nnabla nnabla-ext-cuda110 nnabla-converter

* Build flatbuffers

In [None]:
!git clone https://github.com/google/flatbuffers.git
!cd flatbuffers && cmake -G "Unix Makefiles" && make

* Train a simple CNN with nnabla

In [None]:
import nnabla as nn
import nnabla.functions as F
import nnabla.parametric_functions as PF

def network(image, test=False):
    """Build the small residual CNN used in this notebook and return its logits.

    The graph is: a 24-channel convolution stem, a residual block of two
    conv+BN layers, 2x2 max pooling, a 32-channel conv+BN+ReLU layer, a second
    residual block, 2x2 max pooling, and two affine layers (50 then 10 units).

    Args:
        image: Input variable fed to the first convolution.
        test: When True, batch normalization is built with ``batch_stat=False``;
            otherwise with ``batch_stat=True``.

    Returns:
        The output of the final 10-unit affine layer ('fc4').
    """
    use_batch_stat = not test
    kernel = (3, 3)
    padding = (1, 1)

    def conv_bn(inp, channels, conv_name, bn_name):
        # 3x3 padded convolution immediately followed by batch normalization.
        out = PF.convolution(inp, channels, kernel, pad=padding, name=conv_name)
        return PF.batch_normalization(out, batch_stat=use_batch_stat, name=bn_name)

    # Stem: its activation is also the shortcut of the first residual block.
    stem = F.relu(PF.convolution(image, 24, kernel, pad=padding, name='conv1'))

    # First residual block (24 channels).
    branch = conv_bn(stem, 24, 'conv2', 'bn1')
    branch = conv_bn(branch, 24, 'conv3', 'bn2')
    feat = F.relu(F.add2(stem, branch))
    feat = F.max_pooling(feat, (2, 2))

    # Widen to 32 channels; this activation is the second shortcut.
    shortcut = F.relu(conv_bn(feat, 32, 'conv4', 'bn3'))

    # Second residual block (32 channels).
    branch = conv_bn(shortcut, 32, 'conv5', 'bn4')
    branch = conv_bn(branch, 32, 'conv6', 'bn5')
    feat = F.relu(F.add2(shortcut, branch))
    feat = F.max_pooling(feat, (2, 2))

    # Classifier head.
    hidden = F.relu(PF.affine(feat, 50, name='fc3'))
    return PF.affine(hidden, 10, name='fc4')


In [None]:
import numpy
import struct
import zlib

from nnabla.logger import logger
from nnabla.utils.data_iterator import data_iterator
from nnabla.utils.data_source import DataSource
from nnabla.utils.data_source_loader import download


def _download_and_inflate(uri):
    """Download the gzip-compressed file at ``uri`` and return its raw bytes."""
    r = download(uri)
    try:
        # MAX_WBITS | 32 makes zlib auto-detect a gzip or zlib header.
        return zlib.decompress(r.read(), zlib.MAX_WBITS | 32)
    finally:
        # Release the handle even if reading or decompression fails.
        r.close()


def load_mnist(train=True):
    """Download MNIST images and labels and return them as uint8 arrays.

    Args:
        train: If True, fetch the training files; otherwise the ``t10k``
            test files.

    Returns:
        A tuple ``(images, labels)`` where ``images`` has shape
        ``(size, 1, height, width)`` (dimensions read from the file header)
        and ``labels`` has shape ``(N, 1)``. Both are ``numpy.uint8`` arrays
        backed by the downloaded buffer.
    """
    if train:
        image_uri = 'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz'
        label_uri = 'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz'
    else:
        image_uri = 'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz'
        label_uri = 'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz'

    logger.info('Getting label data from {}.'.format(label_uri))
    data = _download_and_inflate(label_uri)
    # The label payload starts after an 8-byte header (two big-endian uint32).
    # `offset=` avoids copying the payload the way `data[8:]` would.
    labels = numpy.frombuffer(data, numpy.uint8, offset=8).reshape(-1, 1)
    logger.info('Getting label data done.')

    logger.info('Getting image data from {}.'.format(image_uri))
    data = _download_and_inflate(image_uri)
    # 16-byte big-endian header: first field ignored, then count, rows, columns.
    _, size, height, width = struct.unpack_from('>IIII', data, 0)
    images = numpy.frombuffer(data, numpy.uint8, offset=16).reshape(
        size, 1, height, width)
    logger.info('Getting image data done.')

    return images, labels


class MnistDataSource(DataSource):
    '''
    Data source that serves MNIST samples fetched directly from the
    Internet (yann.lecun.com).
    '''

    def __init__(self, train=True, shuffle=False, rng=None):
        super(MnistDataSource, self).__init__(shuffle=shuffle)
        self._train = train

        images, labels = load_mnist(train)
        self._images = images
        self._labels = labels

        self._size = self._labels.size
        self._variables = ('x', 'y')
        # Fall back to a fixed-seed generator when the caller supplies none.
        self.rng = numpy.random.RandomState(313) if rng is None else rng
        self.reset()

    def _get_data(self, position):
        # Translate the iteration position through the current index order.
        idx = self._indexes[position]
        return (self._images[idx], self._labels[idx])

    def reset(self):
        # Draw a fresh permutation when shuffling, otherwise keep file order.
        order = (self.rng.permutation(self._size) if self._shuffle
                 else numpy.arange(self._size))
        self._indexes = order
        super(MnistDataSource, self).reset()

    @property
    def images(self):
        """Return a copy of all images, shaped (N, 1, H, W)."""
        return self._images.copy()

    @property
    def labels(self):
        """Return a copy of all labels, shaped (N, 1)."""
        return self._labels.copy()


def data_iterator_mnist(batch_size,
                        train=True,
                        rng=None,
                        shuffle=True,
                        with_memory_cache=False,
                        with_file_cache=False):
    """Create a data iterator over a freshly constructed ``MnistDataSource``.

    ``train``, ``shuffle`` and ``rng`` configure the data source; ``batch_size``,
    ``rng`` and the two cache flags are forwarded to ``data_iterator``.
    """
    source = MnistDataSource(train=train, shuffle=shuffle, rng=rng)
    return data_iterator(
        source, batch_size, rng, with_memory_cache, with_file_cache)


In [None]:
import nnabla.solver as S
import numpy as np

def get_context(context, gpu_id):
    """Return the nnabla extension context ``context`` for device ``gpu_id``."""
    from nnabla import ext_utils
    return ext_utils.get_extension_context(context, device_id=gpu_id)

def save_model(x, y, nnpfile):
    """Save a single-network, single-executor description to ``nnpfile``.

    Args:
        x: Input variable, registered under the name 'image'.
        y: Output variable, registered under the name 'pred'.
        nnpfile: Destination path handed to ``nnabla.utils.save.save``.
    """
    from nnabla.utils.save import save

    network_entry = {
        'name': 'mnist',
        'batch_size': 1,
        'outputs': {'pred': y},
        'names': {'image': x},
    }
    executor_entry = {
        'name': 'runtime',
        'network': 'mnist',
        'data': ['image'],
        'output': ['pred'],
    }
    save(nnpfile, {'networks': [network_entry],
                   'executors': [executor_entry]})

def categorical_error(pred, label):
    """Return the fraction of rows of ``pred`` whose argmax differs from ``label``.

    Args:
        pred: 2-D array of per-class scores, one row per sample.
        label: Array of integer class indices; it is compared in flattened order.
    """
    predicted = pred.argmax(axis=1)
    mismatches = predicted != label.flat
    return mismatches.mean()

def _validation_error(vdata, vimage, vlabel, vpred, val_iter, ctx):
    """Return the mean classification error over ``val_iter`` batches of ``vdata``."""
    total = 0.0
    for _ in range(val_iter):
        val_x, val_y = vdata.next()
        # Apply the same /255 scaling that the training inputs receive.
        vimage.d, vlabel.d = val_x.astype(np.float32) / 255.0, val_y
        vpred.forward(clear_buffer=True)
        vpred.data.cast(np.float32, ctx)
        total += categorical_error(vpred.d, vlabel.d)
    return total / val_iter


def train():
    """Train the CNN on MNIST, then write the calibration data and the model.

    Runs ``max_iter`` Adam steps, printing the validation error every
    ``val_interval`` steps and once more after training. Afterwards it saves
    20 normalized training batches to 'mnist.npy' (the representative dataset
    used for quantization) and the inference graph to 'mnist.nnp'.
    """
    batch_size = 128
    learning_rate = 0.001
    val_interval = 1000
    max_iter = 20000
    val_iter = 20
    weight_decay = 0

    # Ask nnabla for the active context rather than relying on a notebook-level
    # global named `ctx` having been defined before this function is called.
    ctx = nn.get_current_context()

    # Training graph.
    image = nn.Variable([batch_size, 1, 28, 28])
    label = nn.Variable([batch_size, 1])
    pred = network(image)
    loss = F.mean(F.softmax_cross_entropy(pred, label))

    # Validation graph, built with test=True.
    vimage = nn.Variable([batch_size, 1, 28, 28])
    vlabel = nn.Variable([batch_size, 1])
    vpred = network(vimage, test=True)

    solver = S.Adam(learning_rate)
    solver.set_parameters(nn.get_parameters())

    data = data_iterator_mnist(
        batch_size, True, rng=np.random.RandomState(1223))
    vdata = data_iterator_mnist(batch_size, False)

    # Training loop.
    for i in range(max_iter):
        if i % val_interval == 0:
            ve = _validation_error(vdata, vimage, vlabel, vpred, val_iter, ctx)
            print("Step:{}  Val Error:{}".format(i, ve))

        train_x, train_y = data.next()
        image.d, label.d = train_x.astype(np.float32) / 255.0, train_y

        solver.zero_grad()
        loss.forward(clear_no_need_grad=True)
        loss.backward(clear_buffer=True)
        solver.weight_decay(weight_decay)
        solver.update()

    # Final validation: use the validation iterator and normalized inputs,
    # and report the result instead of discarding it.
    ve = _validation_error(vdata, vimage, vlabel, vpred, val_iter, ctx)
    print("Final Val Error:{}".format(ve))

    # Collect the representative dataset for quantization:
    # 20 batches of `batch_size` normalized training images.
    represent_dataset = []
    for _ in range(20):
        x, _labels = data.next()
        represent_dataset.append(x.astype(np.float32) / 255.0)
    represent_dataset = np.concatenate(represent_dataset, axis=0)

    # Save the representative dataset and the nnp model.
    np.save('mnist.npy', represent_dataset)
    save_model(vimage, vpred, 'mnist.nnp')

In [None]:
# Build the 'cudnn' extension context for device '0', install it as nnabla's
# default context, and then run training.
ctx = get_context('cudnn', '0')
nn.set_default_context(ctx)
train()

* Convert the nnp model to tflite format

In [None]:
# convert nnp to float32 tflite
!PATH=$PATH:$(pwd)/flatbuffers nnabla_cli convert -b 1 mnist.nnp mnist.tflite
# convert nnp to int8 tflite
!PATH=$PATH:$(pwd)/flatbuffers nnabla_cli convert -b 1 mnist.nnp mnist_int8.tflite --quantization --dataset mnist.npy

# Evaluate the tflite model

In [None]:
import tensorflow as tf
def _quantize_input(image, details):
    """Convert a float ``image`` to the dtype the interpreter input expects.

    For int8/uint8 inputs the image is mapped with the tensor's
    (scale, zero_point), rounded to the nearest integer and clipped to the
    dtype's range, so the final cast neither truncates nor wraps around.
    """
    dtype = details['dtype']
    if dtype == np.int8 or dtype == np.uint8:
        scale, zero_point = details["quantization"]
        # A zero scale means no quantization parameters; skip the mapping
        # rather than divide by zero.
        if scale:
            limits = np.iinfo(dtype)
            image = np.clip(np.rint(image / scale + zero_point),
                            limits.min, limits.max)
    return image.astype(dtype)


def _dequantize_output(output, details):
    """Map a raw interpreter output back to float32 real values."""
    output = output.astype(np.float32)
    dtype = details['dtype']
    # Handle uint8 as well as int8 so this mirrors the input quantization.
    if dtype == np.int8 or dtype == np.uint8:
        scale, zero_point = details["quantization"]
        if scale:
            output = (output - zero_point) * scale
    return output


def evaluate_model(tflite, test_images, test_labels):
    """Run a tflite model over ``test_images`` and return its accuracy.

    Args:
        tflite: Path of the tflite model, passed to ``tf.lite.Interpreter``.
        test_images: Iterable of float images. Each one gets a leading batch
            axis and has axis 1 moved last before being fed to the model,
            i.e. (C, H, W) becomes (1, H, W, C).
        test_labels: 1-D array of ground-truth class indices.

    Returns:
        The fraction of images whose argmax prediction equals its label.
    """
    interpreter = tf.lite.Interpreter(tflite)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()[0]
    output_details = interpreter.get_output_details()[0]
    input_index = input_details["index"]
    output_index = output_details["index"]

    prediction_digits = []
    for i, test_image in enumerate(test_images):
        if i % 1000 == 0:
            print('Evaluated on {n} results so far.'.format(n=i))
        test_image = np.expand_dims(test_image, axis=0)
        test_image = np.transpose(test_image, (0, 2, 3, 1))

        interpreter.set_tensor(
            input_index, _quantize_input(test_image, input_details))
        interpreter.invoke()
        # get_tensor returns a copy, so no reference to the interpreter's
        # internal buffer is held across the next invoke().
        output = _dequantize_output(
            interpreter.get_tensor(output_index)[0], output_details)
        prediction_digits.append(np.argmax(output))

    print('\n')
    # Compare prediction results with ground truth labels to calculate accuracy.
    prediction_digits = np.array(prediction_digits)
    accuracy = (prediction_digits == test_labels).mean()
    return accuracy

In [None]:
# Load the MNIST test split and scale the pixels by 1/255.
test_images, test_labels = load_mnist(False)
test_images = test_images.astype(np.float32) / 255.0

# Paths of the float32 model and of the int8-quantized model.
tflite_weight = './mnist.tflite'
int8_tflite_weight = './mnist_int8.tflite'

# Evaluate the float32 tflite first, then the int8 tflite.
for model_path in (tflite_weight, int8_tflite_weight):
    acc = evaluate_model(model_path, test_images, test_labels[..., 0])
    print("accuracy of {} is: {}\n".format(model_path, acc))

In [None]:
!ls -l -h {tflite_weight}
!ls -l -h {int8_tflite_weight}

We can see that the accuracy of the int8-quantized tflite model is very close to that of the float32 model, while its size is only about a quarter of the float32 model's.  
Now, try this converter with your own model.  
Note that not all NNabla functions can be converted to quantized tflite ops.
You can check this page for more details: https://nnabla.readthedocs.io/en/latest/python/file_format_converter/INT8_TFLite_Support_Status.html