# Converting to TensorFlow Lite using pytorch_to_keras

This notebook shows how to convert mobilenet_v2 from a pytorch model into a quantized TensorFlow Lite model. First, use the python library `pytorch2keras` to convert the model into a Keras model, then follow the usual steps to export from Keras as a quantised int8 tflite model.

Ensure that you have installed Python 3.8 and the packages listed in `../requirements.txt`.

In [None]:
import sys
import os

# Make the parent directory importable so the local helper module can be loaded
module_path = os.path.abspath("..")
already_on_path = module_path in sys.path
if not already_on_path:
    sys.path.append(module_path)

In [None]:
# The original libraries are unable to convert Relu6 operations
# Install the xmos forks of onnx2keras and pytorch2keras from their `fix/relu6`
# branches instead. `{sys.executable} -m pip` runs pip through the interpreter
# that `sys.executable` points to.
!{sys.executable} -m pip install https://github.com/xmos/onnx2keras/archive/refs/heads/fix/relu6.zip
!{sys.executable} -m pip install https://github.com/xmos/pytorch2keras/archive/refs/heads/fix/relu6.zip

In [None]:
import torch
import io, os, shutil
import tensorflow as tf
import numpy as np
import tflite
from pytorch2keras import pytorch_to_keras
from torch.autograd import Variable

## Import PyTorch Model
For this example, we use mobilenet_v2.

In [None]:
# Fetch the pretrained mobilenet_v2 entry point from the torchvision hub repository
hub_repo, hub_entry = "pytorch/vision:v0.10.0", "mobilenet_v2"
pytorch_model = torch.hub.load(hub_repo, hub_entry, pretrained=True)
# Put the model into eval mode; as the last expression of the cell this also
# displays the module
pytorch_model.eval()

## Run Inference on PyTorch Model

First, let's run inference on the PyTorch model directly, just to see how it works.

In [None]:
# Download an image to test against
import urllib.request

url, filename = ("https://github.com/pytorch/hub/raw/master/images/dog.jpg", "dog.jpg")
# `urllib.URLopener` only existed in Python 2, so call the Python 3 API directly
# rather than hiding the failure behind a bare `except`. Skip the download when
# the file is already present so re-running the notebook needs no network access.
if not os.path.exists(filename):
    urllib.request.urlretrieve(url, filename)

import requests

# Download Image Labels
resp = requests.get(
    "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt",
    timeout=30,
)
# Fail loudly on an HTTP error instead of parsing an error page as class labels
resp.raise_for_status()
# Read the categories: one label per line, so the list index is the class index
categories = [s.strip() for s in resp.text.splitlines()]

In [None]:
# Input dimensions shared by the cells below (batch, channels, height, width)
batch_size, channels = 1, 3
height, width = 224, 224

In [None]:
from PIL import Image
from torchvision import transforms

# Load the downloaded test image from disk
input_image = Image.open(filename)

# Preprocessing steps, applied in the order listed: resize, centre-crop to
# `height`, convert to a tensor, then normalise each channel.
preprocess_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(height),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
preprocess = transforms.Compose(preprocess_steps)

# PyTorch models take channel-first (BCHW) input. This tensor has no batch axis
# yet; it is added with unsqueeze(0) right before inference.
input_tensor = preprocess(input_image)

In [None]:
# Prepend a batch axis: the model expects a mini-batch rather than a single sample
input_batch = input_tensor.unsqueeze(0)

# Inference only, so gradient tracking is switched off
with torch.no_grad():
    output = pytorch_model(input_batch)

# Turn the scores of the single batch element into probabilities
probabilities = torch.nn.functional.softmax(output[0], dim=0)

# Keep the five most likely classes and pair each label with its probability
vals, idxs = torch.topk(probabilities, 5)
top_indices, top_values = idxs.tolist(), vals.tolist()
pytorch_results = [
    (categories[class_idx], class_prob)
    for class_idx, class_prob in zip(top_indices, top_values)
]
for cat, prob in pytorch_results:
    print(cat, ":", prob)

## Convert to Keras

In [None]:
def pytorch_to_keras_model(pytorch_model, input_shape) -> tf.keras.Model:
    """Convert a PyTorch model to a channel-last Keras model via ``pytorch_to_keras``.

    Args:
        pytorch_model: The PyTorch model to convert.
        input_shape: Shape of a single input sample without the batch axis. It
            is concatenated onto a tuple, so it must be tuple-like (for example
            a ``torch.Size``).

    Returns:
        The value returned by ``pytorch_to_keras`` for the given model.
    """
    # Random batch of one sample, handed to the converter as its example input
    batched_shape = (1,) + input_shape
    dummy_input = Variable(torch.FloatTensor(np.random.uniform(0, 1, batched_shape)))

    conversion_options = dict(
        verbose=True,
        name_policy="renumerate",
        change_ordering=True,  # change channel_first to channel_last
    )
    return pytorch_to_keras(
        pytorch_model, dummy_input, [input_shape], **conversion_options
    )

In [None]:
# `input_tensor.shape` has no batch axis; pytorch_to_keras_model prepends a
# batch dimension of 1 when it builds its example input.
keras_model = pytorch_to_keras_model(pytorch_model, input_tensor.shape)

### Check keras conversion

In [None]:
def softmax(xs, axis=0):
    """Return the softmax of ``xs`` along ``axis``.

    The maximum along ``axis`` is subtracted before exponentiating, so large
    scores do not overflow to ``inf`` (which would give ``nan`` after the
    division). Softmax is unchanged by that shift.

    Args:
        xs: Array-like of scores. For the 1-D score vectors used in this
            notebook the default ``axis=0`` normalises over all entries.
        axis: Axis to normalise over. Defaults to 0, the axis the builtin
            ``sum`` reduces over.

    Returns:
        ``numpy.ndarray`` with the same shape as ``xs`` whose entries along
        ``axis`` sum to 1.
    """
    scores = np.asarray(xs)
    shifted = scores - np.max(scores, axis=axis, keepdims=True)
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=axis, keepdims=True)


# transpose the input_batch into BHWC order for tensorflow
tf_input_data = np.transpose(input_batch.numpy(), [0, 2, 3, 1])

keras_output_data = keras_model(tf_input_data)

# The PyTorch and TFLite checks both pass the model output through softmax
# before reporting it. Do the same here so the three printed lists are directly
# comparable probabilities rather than raw scores.
probs = softmax(np.asarray(keras_output_data[0]))
data = zip(range(len(probs)), probs)
keras_results = [
    (categories[idx], prob)
    for (idx, prob) in sorted(data, key=lambda x: x[1], reverse=True)[:5]
]
for cat, prob in keras_results:
    print(cat, ":", prob)

## Convert to tflite

We will still feed the data into the model in float32 format for convenience, but the internals of the model will be int8. This requires representative data, but as we interface in float32 we can reuse the PyTorch preprocessing.

This conversion follows the method from [keras_to_xcore.ipynb](https://colab.research.google.com/github/xmos/ai_tools/blob/develop/docs/notebooks/keras_to_xcore.ipynb)

### Representative Dataset
To convert a model into a TFLite flatbuffer, a representative dataset is required to help in quantisation. Refer to [Converting a keras model into an xcore optimised tflite model](https://colab.research.google.com/github/xmos/ai_tools/blob/develop/docs/notebooks/keras_to_xcore.ipynb) for more details on this.

In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds

# Use a dedicated name for the calibration data. The generator below looks the
# dataset up lazily when the converter calls it, and the accuracy section later
# binds the generic name `ds` to a different (unbatched) dataset; sharing that
# name would break a re-run of the conversion cell.
calibration_ds = (
    tfds.load("imagenette", split="train", as_supervised=True, shuffle_files=True)
    .shuffle(1000)
    .batch(1)
    .prefetch(10)
    .take(1000)
)


# Iterate over the sampled images and preprocess them
def representative_dataset():
    """Yield calibration inputs for the TFLite converter.

    Each batch-of-one image is run through the PyTorch ``preprocess`` pipeline
    and transposed from BCHW to BHWC, then yielded as a one-element list.
    """
    for image, _ in calibration_ds:
        # image[0] drops the batch axis added by .batch(1)
        pil_img = tf.keras.utils.array_to_img(image[0])
        pytorch_batch = preprocess(pil_img).unsqueeze(0)
        tf_batch = np.transpose(pytorch_batch.numpy(), [0, 2, 3, 1])
        yield [tf_batch]

### Conversion Process

In [None]:
# Now do the conversion to int8
import tensorflow as tf

# Build a converter for the Keras model and request int8 builtin ops internally,
# calibrated with the representative dataset
converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
converter.representative_dataset = representative_dataset
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
# The model boundary stays float32 so callers can feed and read float data
converter.inference_input_type = converter.inference_output_type = tf.float32

tflite_int8_model = converter.convert()

# Write the converted flatbuffer to disk
tflite_int8_model_path = "mobilenet_v2.tflite"
with open(tflite_int8_model_path, "wb") as model_file:
    model_file.write(tflite_int8_model)

### Run inference

In [None]:
# Load the saved flatbuffer into the TFLite interpreter and allocate its tensors
tfl_interpreter = tf.lite.Interpreter(model_path=tflite_int8_model_path)
tfl_interpreter.allocate_tensors()

tfl_input_details = tfl_interpreter.get_input_details()
tfl_output_details = tfl_interpreter.get_output_details()

# Feed the channel-last test image prepared for the Keras check and run the model
input_index = tfl_input_details[0]["index"]
output_index = tfl_output_details[0]["index"]
tfl_interpreter.set_tensor(input_index, tf_input_data)
tfl_interpreter.invoke()

tfl_output_data = tfl_interpreter.get_tensor(output_index)

# Convert the scores of the single batch element to probabilities and rank them
probs = softmax(tfl_output_data[0])
data = enumerate(probs)
ranked = sorted(data, key=lambda pair: pair[1], reverse=True)
tfl_int8_results = [(categories[idx], prob) for (idx, prob) in ranked[:5]]
for cat, prob in tfl_int8_results:
    print(cat, ":", prob)

## Analyse Model

### Check Operator Counts

Let's take a look at the operator counts inside the converted model. This uses a helper function defined in `../utils`, but this step is not necessary to convert the model.

In [None]:
import utils

# Pass the in-memory flatbuffer returned by converter.convert(), not the file path
utils.print_operator_counts(tflite_int8_model)

### Accuracy

Let's compare the accuracy of the converted model to the original PyTorch model.

To do this, we take a large sample from imagenet_v2 and compare the classifications returned by the models.

In [None]:
import tensorflow_datasets as tfds
import tensorflow as tf
import requests
from typing import List

# Load the imagenet_v2 test split as (image, label) pairs, plus its metadata
ds, info = tfds.load(
    name="imagenet_v2",
    split="test",
    shuffle_files=False,
    as_supervised=True,
    with_info=True,
)
# Shuffle through a 100-element buffer; the order is redrawn on every iteration
ds = ds.shuffle(buffer_size=100, reshuffle_each_iteration=True)

In [None]:
def accuracy_tflite(top_n: int = 1, samples=1000, verbose=False) -> float:
    """Measure the top-N accuracy of the converted TFLite model on samples from ``ds``.

    Every image is preprocessed with the PyTorch ``preprocess`` pipeline,
    transposed to BHWC order and run through the module-level
    ``tfl_interpreter``.

    Args:
        top_n: A prediction counts as correct when the true label is among the
            ``top_n`` highest-scoring classes. Must be between 1 and 1000.
        samples: Number of dataset elements to evaluate.
        verbose: When True, print the true and predicted categories and display
            the image for every misclassified sample.

    Returns:
        Fraction of the evaluated samples that were classified correctly.

    Raises:
        ValueError: If ``top_n`` is out of range, or no samples were evaluated.
    """
    # Validate the argument itself; checking a notebook-level variable here
    # would raise NameError whenever that variable happens not to exist.
    if top_n < 1 or top_n > 1000:
        raise ValueError(f"top_n must be between 1 and 1000, got {top_n}")

    # take subset of dataset
    selection = ds.prefetch(10).take(samples)

    correct = 0
    incorrect = 0

    for image, label in selection:
        trueCatIdx = tf.get_static_value(label)
        # convert to PIL.Image
        img = tf.keras.utils.array_to_img(image)

        # preprocess using PyTorch functions, then reorder BCHW -> BHWC for TFLite
        pytorch_batch = preprocess(img).unsqueeze(0)
        tf_batch = np.transpose(pytorch_batch.numpy(), [0, 2, 3, 1])

        # use same tflite interpreter as before
        tfl_interpreter.set_tensor(tfl_input_details[0]["index"], tf_batch)
        tfl_interpreter.invoke()

        output = tfl_interpreter.get_tensor(tfl_output_details[0]["index"])

        # (index, score) pairs ordered by score, highest first
        ranked = sorted(enumerate(output[0]), key=lambda x: x[1], reverse=True)

        top_n_results: List[int] = [idx for (idx, _) in ranked[:top_n]]

        if trueCatIdx in top_n_results:
            correct = correct + 1
        else:
            incorrect = incorrect + 1
            if verbose:
                print("--incorrect--")
                print(f"True Category: {categories[trueCatIdx]}({trueCatIdx})")
                print(
                    [
                        f"Top-{top_n} categories: {categories[idx]}({idx})"
                        for idx in top_n_results
                    ]
                )
                display(img)

    total = correct + incorrect
    if total == 0:
        # Avoid an opaque ZeroDivisionError when nothing was iterated
        raise ValueError("no samples were evaluated; check `samples` and the dataset")

    accuracy = correct / total
    print(
        f"Top-{top_n} accuracy (TFLite Model): {accuracy * 100}% ({correct}/{total})"
    )
    return accuracy

In [None]:
def accuracy_torch(top_n: int = 1, samples=1000, verbose=False) -> float:
    """Measure the top-N accuracy of the original PyTorch model on samples from ``ds``.

    Every image is preprocessed with the PyTorch ``preprocess`` pipeline and
    run through the module-level ``pytorch_model`` with gradients disabled.

    Args:
        top_n: A prediction counts as correct when the true label is among the
            ``top_n`` highest-scoring classes. Must be between 1 and 1000.
        samples: Number of dataset elements to evaluate.
        verbose: When True, print the true and predicted categories and display
            the image for every misclassified sample.

    Returns:
        Fraction of the evaluated samples that were classified correctly.

    Raises:
        ValueError: If ``top_n`` is out of range, or no samples were evaluated.
    """
    # Validate the argument itself; checking a notebook-level variable here
    # would raise NameError whenever that variable happens not to exist.
    if top_n < 1 or top_n > 1000:
        raise ValueError(f"top_n must be between 1 and 1000, got {top_n}")

    # take subset of dataset
    selection = ds.prefetch(10).take(samples)

    correct = 0
    incorrect = 0

    for image, label in selection:
        trueCatIdx = tf.get_static_value(label)

        # convert to PIL.Image
        img = tf.keras.utils.array_to_img(image)
        input_batch = preprocess(img).unsqueeze(0)

        with torch.no_grad():
            output = pytorch_model(input_batch)

        # Indices of the top_n highest-scoring classes as plain Python ints, so
        # they can be used for the membership test and the verbose report.
        _, idxs = torch.topk(output[0], top_n)
        top_n_results = idxs.tolist()

        if trueCatIdx in top_n_results:
            correct = correct + 1
        else:
            incorrect = incorrect + 1
            if verbose:
                print("--incorrect--")
                print(f"True Category: {categories[trueCatIdx]}({trueCatIdx})")
                print(
                    [
                        f"Top-{top_n} categories: {categories[idx]}({idx})"
                        for idx in top_n_results
                    ]
                )
                display(img)

    total = correct + incorrect
    if total == 0:
        # Avoid an opaque ZeroDivisionError when nothing was iterated
        raise ValueError("no samples were evaluated; check `samples` and the dataset")

    accuracy = correct / total
    print(
        f"Top-{top_n} accuracy (PyTorch Model): {accuracy * 100}% ({correct}/{total})"
    )
    return accuracy

In [None]:
# Report top-1 through top-5 accuracy for both models on 500 samples each
samples = 500
for n in range(5):
    rank = n + 1
    accuracy_torch(rank, samples)
    accuracy_tflite(rank, samples)