<a href="https://colab.research.google.com/github/sayakpaul/Adventures-in-TensorFlow-Lite/blob/master/TUNIT_Conversion_to_TF_Lite.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This notebook presents a demo of the TUNIT paper ([Rethinking the Truly Unsupervised Image-to-Image Translation](https://arxiv.org/abs/2006.06500)). The paper's GitHub repository can be found [here](https://github.com/clovaai/tunit). The notebook also demonstrates how to convert PyTorch models to TF Lite via ONNX.

![](https://github.com/clovaai/tunit/raw/master/resrc/teaser_3row.png)

<center>Source: Original Paper</center>

Note that the predictions from the converted TF Lite models look faulty, but this notebook may still serve as a reference for the conversion workflow.

## Setup

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Please note that we used the **animalFaces10_1_00** pre-trained checkpoints. To make them available in Colab, I copied the files from [here](https://drive.google.com/drive/folders/1rU1B9OLQjYMBzU6VQX7UwLxod2WzOfNz?usp=sharing) into a folder named `animalFaces10_1_00` in my personal Google Drive (mounted above).

In [None]:
!git clone https://github.com/clovaai/tunit/

In [None]:
%cd tunit

In [None]:
import torch
from models.generator import Generator
from models.guidingNet import GuidingNet
import torch.nn.functional as F
import torchvision.utils as vutils
from PIL import Image
from torchvision.transforms import ToTensor

## Instantiating the model classes

In [None]:
G = Generator(128, 128)
C = GuidingNet(128)

In [None]:
!cp -r /content/drive/My\ Drive/animalFaces10_1_00 .

In [None]:
!ls -lh animalFaces10_1_00

## Loading the checkpoints in the model classes

In [None]:
load_file = 'animalFaces10_1_00/model_4568.ckpt'
checkpoint = torch.load(load_file, map_location='cpu')
G.load_state_dict(checkpoint['G_EMA_state_dict'])
C.load_state_dict(checkpoint['C_EMA_state_dict'])

The reference image must come from one of the domains the model was trained on.

## Gather images for running inference

In [None]:
!wget -O source.jpg https://github.com/NVlabs/FUNIT/raw/master/images/input_content.jpg
!wget -O reference.jpg https://user-images.githubusercontent.com/23406491/84877309-4e7abf80-b0c3-11ea-8f2d-b18d398e9584.jpg

## Prepare the images and then run inference

In [None]:
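G.eval()
C.eval()

# Force 3-channel RGB so that grayscale or RGBA inputs do not break the models
source_image = Image.open('source.jpg').convert('RGB')
reference_image = Image.open('reference.jpg').convert('RGB')

# PIL image -> float tensor in [0, 1] with shape (1, 3, H, W)
x_src = ToTensor()(source_image).unsqueeze(0)
x_ref = ToTensor()(reference_image).unsqueeze(0)

# Resize to 128x128 (F.interpolate uses nearest-neighbor by default)
x_src = F.interpolate(x_src, size=(128, 128))
x_ref = F.interpolate(x_ref, size=(128, 128))

# Scale from [0, 1] to [-1, 1]
x_src = (x_src - 0.5) / 0.5
x_ref = (x_ref - 0.5) / 0.5

# Gradients are not needed for inference
with torch.no_grad():
    s_ref = C.moco(x_ref)      # style code of the reference image
    x_res = G(x_src, s_ref)    # translate the source image using that style code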

vutils.save_image(x_res, 'test_out.jpg', normalize=True, padding=0)
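
The inputs above are scaled to [-1, 1], but `save_image(..., normalize=True)` does not simply undo that scaling: torchvision stretches the image using its own minimum and maximum. The small NumPy sketch below (helper names are hypothetical) contrasts the fixed inverse with that min-max rescaling. The distinction matters when the PyTorch result is later compared by eye against the TF Lite result, since the two may be postprocessed differently.

```python
import numpy as np

def to_model_range(x):
    """Map pixel values from [0, 1] to [-1, 1], i.e. (x - 0.5) / 0.5 as in the cell above."""
    return (x - 0.5) / 0.5

def from_model_range(y):
    """Fixed inverse of to_model_range: map [-1, 1] back to [0, 1]."""
    return y * 0.5 + 0.5

def minmax_rescale(y, eps=1e-5):
    """Stretch values to [0, 1] using the array's own min and max.

    This mirrors torchvision's save_image(..., normalize=True) when no value range is given.
    """
    lo, hi = float(y.min()), float(y.max())
    return (y - lo) / max(hi - lo, eps)

# A hypothetical generator output that does not span the full [-1, 1] range
g = np.array([-0.5, 0.0, 0.5])
print(from_model_range(g))  # values 0.25, 0.5, 0.75
print(minmax_rescale(g))    # values 0.0, 0.5, 1.0
```

The two only agree when the output happens to span exactly [-1, 1]; otherwise min-max rescaling increases the contrast.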

## Visualization

In [None]:
import matplotlib.pyplot as plt
import numpy as np

In [None]:
def imshow(pil_image, title=None):
    np_array = np.asarray(pil_image)
    plt.imshow(np_array)
    if title:
        plt.title(title)

plt.figure(figsize=(10, 10))

plt.subplot(1, 3, 1)
imshow(source_image, 'Source Image')

plt.subplot(1, 3, 2)
imshow(reference_image, 'Reference Image')

plt.subplot(1, 3, 3)
result = Image.open('test_out.jpg')
imshow(result, 'Transformed Image')

## Set up `onnx-tf`

Reference: https://towardsdatascience.com/onnx-made-easy-957e60d16e94/ 

In [None]:
%cd /content/tunit

In [None]:
%tensorflow_version 2.x
import tensorflow as tf 
print(tf.__version__)

In [None]:
!pip install -q tensorflow-addons

In [None]:
!git clone https://github.com/onnx/onnx-tensorflow.git

In [None]:
!pip install onnx

In [None]:
%cd onnx-tensorflow

In [None]:
!pip install -e .

## Convert to TensorFlow graph

In [None]:
import onnx
from onnx_tf.backend import prepare

In [None]:
# Export the generator
torch.onnx.export(G, (x_src, s_ref), 'generator.onnx', input_names=['test_input', 'style_input'], output_names=['test_output'])

In [None]:
# Export the encoder
torch.onnx.export(C, x_ref, 'encoder.onnx', input_names=['test_input'], output_names=['test_output'])

In [None]:
def generate_tf_graph(onnx_file, tf_graph_file):
    # Load ONNX model and convert to TensorFlow format
    onnx_module = onnx.load(onnx_file)
    tf_rep = prepare(onnx_module)

    # Export model as .pb file
    tf_rep.export_graph(tf_graph_file)

In [None]:
# Generate the TF Graph of the generator onnx module
generate_tf_graph('generator.onnx', 'generator.pb')

In [None]:
# Generate the TF Graph of the encoder onnx module
generate_tf_graph('encoder.onnx', 'encoder.pb')

The warnings printed during these conversions can be ignored.

## Inspect the TF graphs

In [None]:
def load_pb(path_to_pb):
    with tf.io.gfile.GFile(path_to_pb, 'rb') as f:
        graph_def = tf.compat.v1.GraphDef()
        graph_def.ParseFromString(f.read())
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name='')
        return graph

def show_ops(path_to_pb):
    tf_graph = load_pb(path_to_pb)

    for op in tf_graph.get_operations():
        print(op.values())

In [None]:
show_ops('generator.pb')

Output to note: `(<tf.Tensor 'test_output:0' shape=(1, 3, 128, 128) dtype=float32>,)`. 

In [None]:
show_ops('encoder.pb')

Output to note: `(<tf.Tensor 'test_output:0' shape=(1, 128) dtype=float32>,)`. This matches the shape of `s_ref`, the output of `C.moco(x_ref)` above.

## Convert to TF Lite

In [None]:
# At the time of writing, Flex (Select TF) ops were only supported in the Python interpreter via tf-nightly
!pip uninstall -q tensorflow
!pip install -q tf-nightly

Restart the runtime at this point. 

In [None]:
import os
import tempfile
import tensorflow as tf 
print(tf.__version__)

In [None]:
def convert_to_tflite(tf_graph, input_arrays):
    # Build a TF1-style converter from the frozen graph exported by onnx-tf
    converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
        graph_def_file=tf_graph,
        input_arrays=input_arrays,
        output_arrays=['test_output']
    )

    # Dynamic range quantization; fall back to TF (Flex) ops where TF Lite builtins are missing
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS,
                                           tf.lite.OpsSet.SELECT_TF_OPS]

    # Convert to TFLite Model
    tflite_model = converter.convert()

    # Write the model to a temporary file and close the file descriptor afterwards
    fd, tflite_path = tempfile.mkstemp('.tflite')
    with os.fdopen(fd, 'wb') as f:
        tflite_model_size = f.write(tflite_model)
    tf_model_size = os.path.getsize(tf_graph)
    print('TensorFlow Model is  {} bytes'.format(tf_model_size))
    print('TFLite Model is      {} bytes'.format(tflite_model_size))
    print('Post training dynamic range quantization saves {} bytes'.format(tf_model_size-tflite_model_size))
    print('Saved TF Lite model to: {}'.format(tflite_path))
    return tflite_path

In [None]:
convert_to_tflite('/content/tunit/onnx-tensorflow/generator.pb', ['test_input', 'style_input'])

In [None]:
convert_to_tflite('/content/tunit/onnx-tensorflow/encoder.pb', ['test_input'])
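
With `tf.lite.Optimize.DEFAULT` and no representative dataset, the converter applies post-training dynamic range quantization: weights are stored as 8-bit integers instead of 32-bit floats, which accounts for most of the size reduction printed above. The NumPy sketch below illustrates the idea with a simplified per-tensor symmetric scheme. It is an illustration only; TF Lite's actual implementation (for example, per-channel scales for some ops) may differ, and the 64x64x3x3 weight shape is arbitrary.

```python
import numpy as np

def quantize_int8_symmetric(weights):
    """Simplified per-tensor symmetric int8 quantization (illustrative, not TF Lite's exact scheme)."""
    scale = np.max(np.abs(weights)) / 127.0
    if scale == 0:
        scale = 1.0  # all-zero tensor: any scale works
    q = np.clip(np.round(weights / scale), -127, 127).astype(np.int8)
    return q, scale

def dequantize(q, scale):
    """Recover approximate float weights from int8 values and the shared scale."""
    return q.astype(np.float32) * scale

rng = np.random.default_rng(0)
w = rng.standard_normal((64, 64, 3, 3)).astype(np.float32)  # arbitrary conv-like weights
q, scale = quantize_int8_symmetric(w)

print(w.nbytes, q.nbytes)  # 147456 36864: a 4x reduction for this tensor
print(np.max(np.abs(w - dequantize(q, scale))))  # at most about scale / 2
```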

In [None]:
# Replace these temporary paths with the TF Lite paths printed by the cells above before running this cell
!cp /tmp/tmpbhh77i06.tflite generator.tflite
!cp /tmp/tmp6144a9n1.tflite encoder.tflite

## Download the TF Lite files (optional)

In [None]:
# Download the TF Lite files 
from google.colab import files
files.download('generator.tflite')
files.download('encoder.tflite')

## Running inference with TF Lite 

### Inspect the model inputs

In [None]:
def load_tflite_model(tflite_model_path):
    # Load the model.
    interpreter = tf.lite.Interpreter(model_path=tflite_model_path)

    # Read the input signatures and allocate the tensors.
    input_details = interpreter.get_input_details()
    interpreter.allocate_tensors()

    # The image input is laid out as (B, C, H, W); report its spatial size as (W, H)
    input_size = input_details[0]['shape'][3], input_details[0]['shape'][2]
    print('Input size of {} model: {}'.format(tflite_model_path, input_size))

    if tflite_model_path == 'generator.tflite':
        style_reference_size = input_details[1]['shape'][0], input_details[1]['shape'][1]
        print('Style reference size of {} model: {}'.format(tflite_model_path, style_reference_size))

    return (interpreter, input_size)

In [None]:
# Load the TF Lite models in the Python TF Lite interpreter
generator_inter, _ = load_tflite_model('generator.tflite')
encoder_inter, _ = load_tflite_model('encoder.tflite')
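
`load_tflite_model` and the inference helpers below address inputs by position (`input_details[0]`, `input_details[1]`). The order TF Lite reports may not follow the order passed to `input_arrays`, so looking inputs up by the names given at export time (`test_input`, `style_input`) is a more defensive option. Below is a sketch of a hypothetical helper; it only relies on the `'name'` and `'index'` keys of the dicts returned by `get_input_details()`.

```python
def input_index(details, name):
    """Return the interpreter index of the input tensor called `name` (hypothetical helper).

    `details` is a list of dicts like those returned by Interpreter.get_input_details().
    A ':0'-style suffix on the stored name is tolerated.
    """
    for d in details:
        if d['name'] == name or d['name'].split(':')[0] == name:
            return d['index']
    available = [d['name'] for d in details]
    raise KeyError('no input named {!r}; available: {}'.format(name, available))

# Hand-written dicts in the shape TF Lite returns (the index values are made up)
fake_details = [{'name': 'style_input', 'index': 1}, {'name': 'test_input', 'index': 0}]
print(input_index(fake_details, 'test_input'))  # 0
```

With a real interpreter this would be used as `interpreter.set_tensor(input_index(interpreter.get_input_details(), 'style_input'), style_bottleneck)`.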

### Prepare the images for inference

In [None]:
# Copy over the sample images to perform inference
!cp /content/tunit/source.jpg .
!cp /content/tunit/reference.jpg .

In [None]:
# Utility to prepare the images
# These steps mirror the PyTorch preprocessing performed above
def load_img(path_to_img, reshape_size=(128, 128)):
    img = tf.io.read_file(path_to_img)
    img = tf.io.decode_image(img, channels=3, expand_animations=False)
    img = tf.image.convert_image_dtype(img, tf.float32)  # uint8 -> float32 in [0, 1]
    img = img[tf.newaxis, :]  # add a batch dimension: (1, H, W, 3)

    # 'nearest' approximates the default mode of F.interpolate
    img = tf.image.resize(img, reshape_size, method='nearest')
    img = (img - 0.5) / 0.5  # scale to [-1, 1]

    return img

In [None]:
# Prepare the images
x_src = load_img('source.jpg')
x_ref = load_img('reference.jpg')

In [None]:
x_src.shape, x_ref.shape

In [None]:
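# The TF Lite models expect NCHW inputs of shape (1, 3, 128, 128), while load_img returns NHWC.
# Transpose (rather than reshape) so that every pixel keeps its own channel values.
x_src_reshaped = tf.transpose(x_src, perm=[0, 3, 1, 2]).numpy()
x_ref_reshaped = tf.transpose(x_ref, perm=[0, 3, 1, 2]).numpy()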
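
Reshaping and transposing are not interchangeable when moving between NHWC and NCHW. A reshape only reinterprets the flat memory order, so turning `(1, 128, 128, 3)` into `(1, 3, 128, 128)` that way mixes values from different pixels and channels; a transpose moves the channel axis instead. A toy NumPy example (NumPy's `reshape`/`transpose` follow the same semantics as `tf.reshape`/`tf.transpose`):

```python
import numpy as np

# A tiny 1x2x2x3 NHWC "image"; the value at [0, h, w, c] is h*6 + w*3 + c
nhwc = np.arange(1 * 2 * 2 * 3).reshape(1, 2, 2, 3)

reshaped = nhwc.reshape(1, 3, 2, 2)      # same memory order, new shape: channels get mixed
transposed = nhwc.transpose(0, 3, 1, 2)  # NHWC -> NCHW: channel c of pixel (h, w) moves to [0, c, h, w]

print(transposed[0, 0])  # channel 0 of every pixel: values 0, 3, 6, 9
print(reshaped[0, 0])    # just the first four flat values 0, 1, 2, 3, not a channel
```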

### Run inference

In [None]:
# Function to run style prediction on preprocessed style image.
# Reference: https://www.tensorflow.org/lite/models/style_transfer/overview#style_transform
def run_style_predict(reference_img, tflite_path):
    # Load the model.
    interpreter = tf.lite.Interpreter(model_path=tflite_path)

    # Set model input.
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    interpreter.set_tensor(input_details[0]["index"], reference_img)

    # Calculate style bottleneck.
    interpreter.invoke()
    # get_tensor returns a copy, so the result stays valid independently of the interpreter
    style_bottleneck = interpreter.get_tensor(
        interpreter.get_output_details()[0]["index"]
    )

    return style_bottleneck

# Calculate style bottleneck for the preprocessed style image.
style_bottleneck = run_style_predict(x_ref_reshaped, 'encoder.tflite')
print('Style Bottleneck Shape:', style_bottleneck.shape)
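
Matching shapes do not guarantee matching values. A quick numerical check is to compare the TF Lite style code with PyTorch's `s_ref`. Because the runtime was restarted, `s_ref` has to be saved before the restart, for example with `np.save('/content/tunit/s_ref.npy', s_ref.detach().numpy())` (a hypothetical path). A minimal sketch of a hypothetical comparison helper:

```python
import numpy as np

def compare_outputs(reference, candidate):
    """Return (max absolute difference, cosine similarity) between two outputs of equal size."""
    reference = np.asarray(reference, dtype=np.float32).ravel()
    candidate = np.asarray(candidate, dtype=np.float32).ravel()
    if reference.shape != candidate.shape:
        raise ValueError('size mismatch: {} vs {}'.format(reference.shape, candidate.shape))
    max_abs_diff = float(np.max(np.abs(reference - candidate)))
    cosine = float(np.dot(reference, candidate) /
                   (np.linalg.norm(reference) * np.linalg.norm(candidate) + 1e-12))
    return max_abs_diff, cosine

# Usage, assuming s_ref was saved to the hypothetical path above before the restart:
#   s_ref_torch = np.load('/content/tunit/s_ref.npy')
#   print(compare_outputs(s_ref_torch, style_bottleneck))
```

A cosine similarity close to 1 with a small maximum difference suggests the encoder survived conversion; larger gaps point at the conversion or the preprocessing rather than the generator.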

In [None]:
# Run style transform on the preprocessed source image
# Reference: https://www.tensorflow.org/lite/models/style_transfer/overview#style_transform
def run_style_transform(style_bottleneck, preprocessed_source_image, tflite_path):
    # Load the model.
    interpreter = tf.lite.Interpreter(model_path=tflite_path)

    # Read the input signatures and allocate the tensors.
    input_details = interpreter.get_input_details()
    interpreter.allocate_tensors()

    # Set model inputs: the source image and the style code from the encoder.
    interpreter.set_tensor(input_details[0]["index"], preprocessed_source_image)
    interpreter.set_tensor(input_details[1]["index"], style_bottleneck)
    interpreter.invoke()

    # Transform content image (get_tensor returns a copy of the output).
    stylized_image = interpreter.get_tensor(
        interpreter.get_output_details()[0]["index"]
    )

    return stylized_image

# Transform the content image using the style bottleneck.
resultant_image = run_style_transform(style_bottleneck, x_src_reshaped, 'generator.tflite')
print('Resultant image shape:', resultant_image.shape)

In [None]:
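# Visualize the resultant image
import matplotlib.pyplot as plt

# The generator returns NCHW output; move the channel axis last for matplotlib
resultant_image = tf.transpose(resultant_image, perm=[0, 2, 3, 1])
resultant_image = tf.squeeze(resultant_image)
# Map from the model's [-1, 1] range back to [0, 1] before clipping
resultant_image = (resultant_image + 1.0) / 2.0
plt.imshow(tf.clip_by_value(resultant_image, 0., 1.))
plt.show()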
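
To save the TF Lite result to disk (for example with `PIL.Image.fromarray`) rather than only plotting it, the NCHW output has to become an HWC `uint8` array. A minimal NumPy sketch, assuming the generator output lies in [-1, 1] like its inputs (values outside that range are clipped); `nchw_to_display` is a hypothetical name:

```python
import numpy as np

def nchw_to_display(output):
    """Convert a (1, 3, H, W) array in [-1, 1] into an (H, W, 3) uint8 image."""
    img = np.asarray(output, dtype=np.float32)[0]  # drop the batch axis: (3, H, W)
    img = np.transpose(img, (1, 2, 0))             # channels last: (H, W, 3)
    img = (img + 1.0) / 2.0                        # [-1, 1] -> [0, 1]
    img = np.clip(img, 0.0, 1.0)
    return (img * 255.0 + 0.5).astype(np.uint8)    # round to the nearest 8-bit value

print(nchw_to_display(np.zeros((1, 3, 2, 2))).shape)  # (2, 2, 3)
```

Applied to the raw array returned by `run_style_transform`, the result can be passed straight to `Image.fromarray(...).save(...)`.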

## Acknowledgements

Thanks to [Kyungjune Baek](https://friedronaldo.github.io/hibkj/) for his guidance in running demo inference in PyTorch. 