##### Copyright 2021 The IREE Authors

In [10]:
#@title Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

# IREE TensorFlow Hub Import

This notebook demonstrates how to download, import, and compile models from [TensorFlow Hub](https://tfhub.dev/). It covers:

* Downloading a model from TensorFlow Hub
* Ensuring the model has serving signatures needed for import
* Importing and compiling the model with IREE

At the end of the notebook, the compilation artifacts are compressed into a .zip file for you to download and use in an application.

See also https://iree.dev/guides/ml-frameworks/tensorflow/.

## Setup

In [11]:
%%capture
!python -m pip install iree-compiler iree-runtime iree-tools-tf -f https://iree.dev/pip-release-links.html

In [12]:
import os
import tensorflow as tf
import tensorflow_hub as hub
import tempfile
from IPython.display import clear_output

from iree.compiler import tf as tfc

# Print version information for future notebook users to reference.
print("TensorFlow version: ", tf.__version__)

ARTIFACTS_DIR = os.path.join(tempfile.gettempdir(), "iree", "colab_artifacts")
os.makedirs(ARTIFACTS_DIR, exist_ok=True)
print(f"Using artifacts directory '{ARTIFACTS_DIR}'")

TensorFlow version:  2.12.0
Using artifacts directory '/tmp/iree/colab_artifacts'


## Import pretrained [`mobilenet_v2`](https://tfhub.dev/google/tf2-preview/mobilenet_v2/classification/4) model

IREE supports importing TensorFlow 2 models exported in the [SavedModel](https://www.tensorflow.org/guide/saved_model) format. The model we'll be importing is already published in that format; other models may need to be converted first.

MobileNet V2 is a family of neural network architectures for efficient on-device image classification and related tasks. This TensorFlow Hub module contains a trained instance of one particular network architecture packaged to perform image classification.
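
Before handing a directory to an importer, it can help to confirm that it looks like a SavedModel on disk. The following is a minimal sketch, assuming the usual layout in which the graph is stored as `saved_model.pb` (or `saved_model.pbtxt`) at the top level of the directory. The helper name `looks_like_saved_model` is hypothetical and is not part of TensorFlow or IREE.

```python
import os
import tempfile


def looks_like_saved_model(path):
    """Return True if `path` contains a top-level SavedModel graph file.

    Hypothetical helper: it only checks file names and does not parse the model.
    """
    return any(
        os.path.isfile(os.path.join(path, name))
        for name in ("saved_model.pb", "saved_model.pbtxt")
    )


# Demonstrate on a throwaway directory that mimics the layout.
with tempfile.TemporaryDirectory() as demo_dir:
    empty_dir_result = looks_like_saved_model(demo_dir)
    open(os.path.join(demo_dir, "saved_model.pb"), "wb").close()
    populated_dir_result = looks_like_saved_model(demo_dir)

print(empty_dir_result, populated_dir_result)
```

In this notebook the same check could be pointed at the path returned by `hub.resolve`. A passing check only says the directory has the right shape; it says nothing about whether the model defines serving signatures.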

In [13]:
#@title Download the pretrained model

# Use the `hub` library to download the pretrained model to the local disk
# https://www.tensorflow.org/hub/api_docs/python/hub
HUB_PATH = "https://tfhub.dev/google/tf2-preview/mobilenet_v2/classification/4"
model_path = hub.resolve(HUB_PATH)
print(f"Downloaded model from tfhub to path: '{model_path}'")

Downloaded model from tfhub to path: '/tmp/tfhub_modules/426589ad685896ab7954855255a52db3442cb38d'


### Check for serving signatures and re-export as needed

Like TensorFlow's `saved_model_cli` and other tools, IREE's compiler tools require "serving signatures" to be defined in SavedModels.

More references:

* https://www.tensorflow.org/tfx/serving/signature_defs
* https://blog.tensorflow.org/2021/03/a-tour-of-savedmodel-signatures.html
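
To make the idea concrete, here is a small sketch of a serving signature on a toy module, independent of MobileNet. The `Doubler` class and its `[1, 4]` input shape are illustrative assumptions; only `tf.Module`, `tf.function`, `tf.TensorSpec`, `tf.saved_model.save`, and `tf.saved_model.load` come from TensorFlow.

```python
import tempfile

import tensorflow as tf


class Doubler(tf.Module):
    """Toy module (hypothetical, unrelated to MobileNet) that doubles its input."""

    @tf.function(input_signature=[tf.TensorSpec([1, 4], tf.float32)])
    def __call__(self, x):
        return x * 2.0


module = Doubler()
export_dir = tempfile.mkdtemp()

# Passing a single function with a fixed input signature as `signatures`
# registers it under the default key, "serving_default".
tf.saved_model.save(module, export_dir, signatures=module.__call__)

# Reload from disk and list the signature keys that were saved.
signature_names = list(tf.saved_model.load(export_dir).signatures.keys())
print(signature_names)
```

Saving the same module without the `signatures` argument would still produce a loadable SavedModel, but tools that look up entry points by signature name would have nothing to find.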

In [14]:
#@title Check for serving signatures

# Load the SavedModel from the local disk and check if it has serving signatures
# https://www.tensorflow.org/guide/saved_model#loading_and_using_a_custom_model
loaded_model = tf.saved_model.load(model_path)
serving_signatures = list(loaded_model.signatures.keys())
print(f"Loaded SavedModel from '{model_path}'")
print(f"Serving signatures: {serving_signatures}")

# Also check with the saved_model_cli:
print("\n---\n")
print("Checking for signature_defs using saved_model_cli:\n")
!saved_model_cli show --dir {model_path} --tag_set serve --signature_def serving_default

Loaded SavedModel from '/tmp/tfhub_modules/426589ad685896ab7954855255a52db3442cb38d'
Serving signatures: []

---

Checking for signature_defs using saved_model_cli:

Traceback (most recent call last):
  File "/usr/local/bin/saved_model_cli", line 8, in <module>
    sys.exit(main())
  File "/usr/local/lib/python3.9/dist-packages/tensorflow/python/tools/saved_model_cli.py", line 1284, in main
    app.run(smcli_main)
  File "/usr/local/lib/python3.9/dist-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.9/dist-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/usr/local/lib/python3.9/dist-packages/tensorflow/python/tools/saved_model_cli.py", line 1282, in smcli_main
    args.func()
  File "/usr/local/lib/python3.9/dist-packages/tensorflow/python/tools/saved_model_cli.py", line 961, in show
    _show_inputs_outputs(
  File "/usr/local/lib/python3.9/dist-packages/tensorflow/python/tools/saved_model_cli.py", line 345

Since the model we downloaded did not include any serving signatures, we'll re-export it with serving signatures defined.

* https://www.tensorflow.org/guide/saved_model#specifying_signatures_during_export

In [15]:
#@title Look up input signatures to use when exporting

# To save serving signatures we need to specify a `ConcreteFunction` with a
# TensorSpec signature. We can determine what this signature should be by
# looking at any documentation for the model or running the saved_model_cli.

!saved_model_cli show --dir {model_path} --all \
    2> /dev/null | grep "inputs: TensorSpec" | tail -n 1

          inputs: TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32, name='inputs')
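
The printed `TensorSpec` line can also be turned into a shape programmatically instead of being copied by hand. The sketch below assumes the line has exactly the format shown above; `parse_tensor_spec` is a hypothetical helper and is not provided by TensorFlow or IREE.

```python
import re


def parse_tensor_spec(line):
    """Extract (shape, dtype) from a printed `TensorSpec(...)` line.

    Hypothetical helper; dynamic dimensions printed as `None` stay `None`.
    """
    match = re.search(r"TensorSpec\(shape=\(([^)]*)\), dtype=tf\.(\w+)", line)
    if match is None:
        raise ValueError(f"No TensorSpec found in: {line!r}")
    dims = [d.strip() for d in match.group(1).split(",") if d.strip()]
    shape = [None if d == "None" else int(d) for d in dims]
    return shape, match.group(2)


line = "inputs: TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32, name='inputs')"
shape, dtype = parse_tensor_spec(line)

# Pin every dynamic dimension to 1. Here that is only the batch dimension,
# which matches running inference on a single image at a time.
static_shape = [1 if dim is None else dim for dim in shape]
print(shape, dtype, static_shape)
```

The resulting `static_shape` is `[1, 224, 224, 3]`, the fixed shape to request when getting a concrete function for export.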


In [16]:
#@title Re-export the model using the known signature

# Get a concrete function using the signature we found above.
# 
# The first element of the shape is a dynamic batch size. We'll be running
# inference on a single image at a time, so set it to `1`. The rest of the
# shape is the fixed image dimensions [height=224, width=224, channels=3].
call = loaded_model.__call__.get_concrete_function(tf.TensorSpec([1, 224, 224, 3], tf.float32))

# Save the model, setting the concrete function as a serving signature.
# https://www.tensorflow.org/guide/saved_model#saving_a_custom_model
resaved_model_path = '/tmp/resaved_model'
tf.saved_model.save(loaded_model, resaved_model_path, signatures=call)
clear_output()  # Skip over TensorFlow's output.
print(f"Saved model with serving signatures to '{resaved_model_path}'")

# Load the model back into memory and check that it has serving signatures now
reloaded_model = tf.saved_model.load(resaved_model_path)
reloaded_serving_signatures = list(reloaded_model.signatures.keys())
print(f"\nReloaded SavedModel from '{resaved_model_path}'")
print(f"Serving signatures: {reloaded_serving_signatures}")

# Also check with the saved_model_cli:
print("\n---\n")
print("Checking for signature_defs using saved_model_cli:\n")
!saved_model_cli show --dir {resaved_model_path} --tag_set serve --signature_def serving_default

Saved model with serving signatures to '/tmp/resaved_model'

Reloaded SavedModel from '/tmp/resaved_model'
Serving signatures: ['serving_default']

---

Checking for signature_defs using saved_model_cli:

The given SavedModel SignatureDef contains the following input(s):
  inputs['inputs'] tensor_info:
      dtype: DT_FLOAT
      shape: (1, 224, 224, 3)
      name: serving_default_inputs:0
The given SavedModel SignatureDef contains the following output(s):
  outputs['output_0'] tensor_info:
      dtype: DT_FLOAT
      shape: (1, 1001)
      name: StatefulPartitionedCall:0
Method name is: tensorflow/serving/predict


### Import and compile the SavedModel with IREE

In [17]:
#@title Import from SavedModel

# The main output file from compilation is a .vmfb "VM FlatBuffer". This file
# can be used to run the compiled model with IREE's runtime.
output_file = os.path.join(ARTIFACTS_DIR, "mobilenet_v2.vmfb")
# As compilation runs, dump an intermediate .mlir file for future inspection.
iree_input = os.path.join(ARTIFACTS_DIR, "mobilenet_v2_iree_input.mlir")

# Since our SavedModel uses signature defs, we use `saved_model_tags` with
# `import_type="SIGNATURE_DEF"`. If the SavedModel used an object graph, we
# would use `exported_names` with `import_type="OBJECT_GRAPH"` instead.

# We'll set `target_backends=["vmvx"]` to use IREE's reference CPU backend.
# We could instead use different backends here, or set `import_only=True` then
# download the imported .mlir file for compilation using native tools directly.

tfc.compile_saved_model(
    resaved_model_path,
    output_file=output_file,
    save_temp_iree_input=iree_input,
    import_type="SIGNATURE_DEF",
    saved_model_tags={"serve"},
    target_backends=["vmvx"])
clear_output()  # Skip over TensorFlow's output.

print(f"Saved compiled output to '{output_file}'")
print(f"Saved iree_input to      '{iree_input}'")

Saved compiled output to '/tmp/iree/colab_artifacts/mobilenet_v2.vmfb'
Saved iree_input to      '/tmp/iree/colab_artifacts/mobilenet_v2_iree_input.mlir'
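
The intermediate `.mlir` file is plain text, so it can be inspected directly. Because it can be large, the sketch below reads only the first few lines. The `head` helper is hypothetical, and the stand-in file exists only so the example runs outside the notebook; in the notebook you would pass `iree_input` instead.

```python
import itertools
import os
import tempfile


def head(path, num_lines=5):
    """Return the first `num_lines` lines of a text file without reading it all."""
    with open(path, "r", encoding="utf-8", errors="replace") as f:
        return [line.rstrip("\n") for line in itertools.islice(f, num_lines)]


# Stand-in file so the sketch runs anywhere; in the notebook, use `iree_input`.
demo_path = os.path.join(tempfile.mkdtemp(), "demo.mlir")
with open(demo_path, "w", encoding="utf-8") as f:
    f.write("module {\n}\n")

preview = head(demo_path)
print("\n".join(preview))
```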


In [18]:
#@title Download compilation artifacts

ARTIFACTS_ZIP = "/tmp/mobilenet_colab_artifacts.zip"

print(f"Zipping '{ARTIFACTS_DIR}' to '{ARTIFACTS_ZIP}' for download...")
!cd {ARTIFACTS_DIR} && zip -r {ARTIFACTS_ZIP} .

# Note: you can also download files using the file explorer on the left
try:
    from google.colab import files
    print("Downloading the artifacts zip file...")
    files.download(ARTIFACTS_ZIP)
except ImportError:
    print("Missing `google.colab` Python package, can't download files")

Zipping '/tmp/iree/colab_artifacts' to '/tmp/mobilenet_colab_artifacts.zip' for download...
  adding: mobilenet_v2.vmfb (deflated 8%)
  adding: mobilenet_v2_iree_input.mlir (deflated 46%)
Downloading the artifacts zip file...
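
Outside Colab, or on systems without the `zip` command, the same archive can be built with the Python standard library. This is a sketch under stated assumptions: the stand-in directory and placeholder files only mimic the artifact names from this notebook, and in practice `ARTIFACTS_DIR` would be used as `root_dir`.

```python
import os
import shutil
import tempfile
import zipfile

# Stand-in artifacts directory so the sketch runs anywhere; in the notebook,
# use ARTIFACTS_DIR instead.
artifacts_dir = tempfile.mkdtemp()
for name in ("mobilenet_v2.vmfb", "mobilenet_v2_iree_input.mlir"):
    with open(os.path.join(artifacts_dir, name), "w") as f:
        f.write("placeholder")

# Write the archive outside the directory being archived so that the zip file
# does not end up inside itself.
zip_base = os.path.join(tempfile.mkdtemp(), "mobilenet_colab_artifacts")
zip_path = shutil.make_archive(zip_base, "zip", root_dir=artifacts_dir)

# List the archive contents to confirm both artifacts were included.
with zipfile.ZipFile(zip_path) as archive:
    archived_names = sorted(archive.namelist())
print(zip_path)
print(archived_names)
```

`shutil.make_archive` appends the `.zip` extension itself, which is why `zip_base` is given without one.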

