<a href="https://colab.research.google.com/github/sayakpaul/ViT-jax2tf/blob/main/conversion.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## References

* https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/README.md
* https://github.com/google-research/vision_transformer/blob/main/vit_jax.ipynb

## Setup

In [None]:
# Quote version specifiers so the shell does not treat `>` as output redirection.
!pip install -q "absl-py>=0.12.0" "chex>=0.0.7" "clu>=0.0.3" "einops>=0.3.0"
!pip install -q flax==0.3.3 ml-collections==0.1.0 tf-nightly
!pip install -q "numpy>=1.19.5" "pandas>=1.1.0"

In [None]:
# Clone repository and pull latest changes.
![ -d vision_transformer ] || git clone --depth=1 https://github.com/google-research/vision_transformer
!cd vision_transformer && git pull

## Imports

In [None]:
import sys

if "./vision_transformer" not in sys.path:
    sys.path.append("./vision_transformer")

from vit_jax import models
from vit_jax import checkpoint
from vit_jax.configs import common as common_config
from vit_jax.configs import models as models_config

from jax.experimental import jax2tf
import tensorflow as tf
import flax
import jax

from PIL import Image
from io import BytesIO
import numpy as np
import requests

In [None]:
print(f"JAX version: {jax.__version__}")
print(f"FLAX version: {flax.__version__}")
print(f"TensorFlow version: {tf.__version__}")

## Classification / Feature Extractor model

In [None]:
#@title Choose a model type
VIT_MODELS = "B_8-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224" #@param ["L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224", "B_16-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224", "R50_L_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224", "R26_S_32-i21k-300ep-lr_0.001-aug_light0-wd_0.03-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.03-res_224", "R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224", "S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224", "B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224", "B_8-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224"]
#@markdown The models were selected based on the criteria shown in [this notebook](https://github.com/sayakpaul/ViT-jax2tf/blob/main/model-selector.ipynb).

print(f"Model type selected: ViT-{VIT_MODELS.split('-')[0]}")

ROOT_GCS_PATH = "gs://vit_models/augreg"
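Each checkpoint name encodes its training setup: the part before `--` describes pre-training and the part after it describes fine-tuning on ImageNet-1k (`imagenet2012`). The sketch below splits a name into those fields. `parse_augreg_name` is a hypothetical helper, and the expansions in the comments (for example `do` as dropout and `sd` as stochastic depth) are our reading of the abbreviations, not an official schema. Note that `res_224` matches the 224x224 input size hard-coded later in this notebook.

```python
# Hypothetical helper: decode an AugReg checkpoint name into its fields.
# The meanings in the comments are our reading of the abbreviations.

def parse_augreg_name(name):
    """Split an AugReg checkpoint name into pre-training and fine-tuning settings."""
    upstream, downstream = name.split("--")

    # Pre-training part: model variant, dataset, epochs, then key_value pairs.
    model, pretrain_data, epochs, *up_pairs = upstream.split("-")
    # Fine-tuning part: dataset, then key_value pairs.
    finetune_data, *down_pairs = downstream.split("-")

    def to_dict(pairs):
        # "lr_0.001" -> ("lr", "0.001"); split only on the first underscore.
        return dict(pair.split("_", 1) for pair in pairs)

    return {
        "model": model,                     # e.g. "B_8" -> ViT-B/8
        "pretrain_data": pretrain_data,     # e.g. "i21k" -> ImageNet-21k
        "epochs": epochs,
        "upstream": to_dict(up_pairs),      # lr, aug, wd, do (dropout), sd (stochastic depth)
        "finetune_data": finetune_data,
        "downstream": to_dict(down_pairs),  # steps, lr, res (input resolution)
    }


name = "B_8-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224"
info = parse_augreg_name(name)
print(info["model"], info["downstream"]["res"])  # B_8 224
```

The `model` field is the same value that `VIT_MODELS.split('-')[0]` extracts to look up the model config.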

In [None]:
classification_model = True

if classification_model:
    num_classes = 1000
    print("Will be converting a classification model.")
else:
    num_classes = None
    print("Will be converting a feature extraction model.")

In [None]:
# Instantiate model class and load the corresponding checkpoints.
config = common_config.get_config()
config.model = models_config.AUGREG_CONFIGS[VIT_MODELS.split('-')[0]]

model = models.VisionTransformer(num_classes=num_classes, **config.model)

path = f"{ROOT_GCS_PATH}/{VIT_MODELS}.npz"
params = checkpoint.load(path)

if not num_classes:
    _ = params.pop("head")
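For a feature extractor, `VisionTransformer(num_classes=None, ...)` has no classification head, so the `head` entry of the checkpoint is removed to keep the parameter tree aligned with the model. The toy sketch below shows the effect on the parameter count. The tree is hypothetical and heavily truncated: leaves are shape tuples standing in for arrays, and the sizes assume a ViT-B/8 (hidden size 768) with a 1000-class head.

```python
from math import prod

# Hypothetical, heavily truncated parameter tree: leaves are array *shapes*
# standing in for the real arrays returned by `checkpoint.load`.
params = {
    "embedding": {"kernel": (8, 8, 3, 768), "bias": (768,)},
    "head": {"kernel": (768, 1000), "bias": (1000,)},
}


def count_params(tree):
    """Recursively sum the number of elements over all leaf shapes."""
    if isinstance(tree, dict):
        return sum(count_params(subtree) for subtree in tree.values())
    return prod(tree)


total = count_params(params)
head = params.pop("head")  # same operation as in the feature-extractor branch
print(total, count_params(head), count_params(params))  # 917224 769000 148224
```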

## Conversion

The code below is adapted from the official jax2tf examples [here](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/README.md).

### Step 1: Get a prediction function out of the JAX model & convert it to a native TF function

In [None]:
predict_fn = lambda params, inputs: model.apply(
    dict(params=params), inputs, train=False
)

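# The classification model is exported for inference only; the feature
# extractor keeps gradients so that it can be fine-tuned.
with_gradient = not num_classes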
tf_fn = jax2tf.convert(
    predict_fn,
    with_gradient=with_gradient,
    polymorphic_shapes=[None, "b, 224, 224, 3"],
)

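We set `polymorphic_shapes` so that the converted model can operate on arbitrary batch sizes: `b` is a symbolic batch dimension, while the spatial and channel dimensions stay fixed at 224x224x3. The `None` entry means the shapes of `params` are taken as-is, without polymorphism. Learn more about shape polymorphism in jax2tf [here](https://github.com/google/jax/tree/main/jax/experimental/jax2tf#shape-polymorphic-conversion).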
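As a mental model of the spec `"b, 224, 224, 3"`: `b` stands for any batch size, while the integer entries must match exactly. The helper below, `matches_poly_spec`, is hypothetical and not part of jax2tf (it ignores the richer syntax jax2tf accepts); it only checks whether a concrete input shape fits such a spec.

```python
def matches_poly_spec(spec, shape):
    """Return True if a concrete `shape` fits a spec such as "b, 224, 224, 3".

    Integer entries must match exactly; any other entry (like "b") is treated
    as a symbolic dimension that accepts any positive size. Hypothetical
    helper for illustration only.
    """
    dims = [d.strip() for d in spec.strip("()").split(",")]
    if len(dims) != len(shape):
        return False
    for dim, size in zip(dims, shape):
        if dim.isdigit():
            if int(dim) != size:
                return False
        elif size < 1:
            return False
    return True


spec = "b, 224, 224, 3"
print(matches_poly_spec(spec, (1, 224, 224, 3)))   # True
print(matches_poly_spec(spec, (32, 224, 224, 3)))  # True
print(matches_poly_spec(spec, (1, 384, 384, 3)))   # False
```

The same idea appears later as the `None` batch dimension in the SavedModel input signature `[None, 224, 224, 3]`.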

### Step 2: Set the trainability of the individual parameter groups and construct the TF graph

In [None]:
param_vars = tf.nest.map_structure(
    lambda param: tf.Variable(param, trainable=with_gradient), params
)
tf_graph = tf.function(
    lambda inputs: tf_fn(param_vars, inputs), autograph=False, jit_compile=True
)
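`tf.nest.map_structure` applies a function to every leaf of a nested structure and returns a structure of the same shape, so every parameter array becomes a `tf.Variable` whose `trainable` flag equals `with_gradient`. The SavedModel wrapper in Step 3 then flattens these with `tf.nest.flatten`, which visits dictionary keys in sorted order. The plain-Python sketch below mimics that behavior for nested dicts only; `Var`, `map_structure`, and `flatten` are simplified stand-ins, not TensorFlow's implementations, and the parameter names are toy values.

```python
from dataclasses import dataclass


@dataclass
class Var:
    """Hypothetical stand-in for tf.Variable: a value plus a trainable flag."""
    value: object
    trainable: bool


def map_structure(fn, tree):
    """Simplified tf.nest.map_structure that handles nested dicts only."""
    if isinstance(tree, dict):
        return {key: map_structure(fn, sub) for key, sub in tree.items()}
    return fn(tree)


def flatten(tree):
    """Simplified tf.nest.flatten: leaves in sorted-key order."""
    if isinstance(tree, dict):
        return [leaf for key in sorted(tree) for leaf in flatten(tree[key])]
    return [tree]


params = {"head": {"kernel": "W_head", "bias": "b_head"}, "embedding": {"kernel": "W_emb"}}

for with_gradient in (False, True):
    param_vars = map_structure(lambda p: Var(p, trainable=with_gradient), params)
    variables = flatten(param_vars)
    trainable = [v for v in variables if v.trainable]
    print(with_gradient, [v.value for v in variables], len(trainable))
# False ['W_emb', 'b_head', 'W_head'] 0
# True ['W_emb', 'b_head', 'W_head'] 3
```

With `with_gradient=False` (the classification model), the wrapper's `trainable_variables` list ends up empty.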

### Step 3: Serialize as a SavedModel

In [None]:
#@title SavedModel wrapper class utility from [here](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/saved_model_lib.py#L128)
class _ReusableSavedModelWrapper(tf.train.Checkpoint):
  """Wraps a function and its parameters for saving to a SavedModel.
  Implements the interface described at
  https://www.tensorflow.org/hub/reusable_saved_models.
  """

  def __init__(self, tf_graph, param_vars):
    """Args:
      tf_graph: a tf.function taking one argument (the inputs), which can be
         tuples/lists/dictionaries of np.ndarray or tensors. The function
         may have references to the tf.Variables in `param_vars`.
      param_vars: the parameters, as tuples/lists/dictionaries of tf.Variable,
         to be saved as the variables of the SavedModel.
    """
    super().__init__()
    # Implement the interface from https://www.tensorflow.org/hub/reusable_saved_models
    self.variables = tf.nest.flatten(param_vars)
    self.trainable_variables = [v for v in self.variables if v.trainable]
    # If you intend to prescribe regularization terms for users of the model,
    # add them as @tf.functions with no inputs to this list. Else drop this.
    self.regularization_losses = []
    self.__call__ = tf_graph


In [None]:
input_signatures = [tf.TensorSpec(shape=[None, 224, 224, 3], dtype=tf.float32)]
model_dir = VIT_MODELS if num_classes else f"{VIT_MODELS}_fe"
signatures = {}
saved_model_options = None

print(f"Saving model to {model_dir} directory.")

In [None]:
signatures[
    tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
] = tf_graph.get_concrete_function(input_signatures[0])

wrapper = _ReusableSavedModelWrapper(tf_graph, param_vars)
if with_gradient:
    if not saved_model_options:
        saved_model_options = tf.saved_model.SaveOptions(
            experimental_custom_gradients=True
        )
    else:
        saved_model_options.experimental_custom_gradients = True
tf.saved_model.save(
    wrapper, model_dir, signatures=signatures, options=saved_model_options
)

# Note that directly saving the `wrapper` to a GCS location is
# also supported.

## Functional test (credits: [Willi Gierke](https://ch.linkedin.com/in/willi-gierke))

### Image preprocessing utilities 

In [None]:
def preprocess_image(image):
    image = np.array(image)
    image_resized = tf.image.resize(image, (224, 224))
    image_resized = tf.cast(image_resized, tf.float32)
    image_resized = (image_resized - 127.5) / 127.5
    return tf.expand_dims(image_resized, 0).numpy()

def load_image_from_url(url):
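    response = requests.get(url)
    response.raise_for_status()
    # Convert to RGB so grayscale or RGBA images also yield 3 channels.
    image = Image.open(BytesIO(response.content)).convert("RGB")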
    image = preprocess_image(image)
    return image

!wget https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt -O ilsvrc2012_wordnet_lemmas.txt
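`preprocess_image` rescales pixel values from [0, 255] to [-1, 1] via `(x - 127.5) / 127.5`. A quick worked check in plain Python (the `normalize` and `denormalize` names are illustrative), including the inverse mapping, which can be handy when displaying a preprocessed image:

```python
def normalize(pixel):
    """Map a pixel value in [0, 255] to [-1, 1], as in `preprocess_image`."""
    return (pixel - 127.5) / 127.5


def denormalize(value):
    """Inverse mapping from [-1, 1] back to [0, 255]."""
    return value * 127.5 + 127.5


for pixel in (0, 127.5, 255):
    print(pixel, normalize(pixel))
# 0 -1.0
# 127.5 0.0
# 255 1.0
```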

### Load image and ImageNet-1k class mappings

In [None]:
with open("ilsvrc2012_wordnet_lemmas.txt", "r") as f:
    lines = f.readlines()
imagenet_int_to_str = [line.rstrip() for line in lines]

img_url = "https://p0.pikrepo.com/preview/853/907/close-up-photo-of-gray-elephant.jpg"
image = load_image_from_url(img_url)

### Inference

This section applies only to classification models. For fine-tuning or feature extraction, follow [this notebook](https://colab.research.google.com/github/sayakpaul/ViT-jax2tf/blob/main/fine_tune.ipynb) instead.

In [None]:
# Load the converted SavedModel and check whether it finds the elephant.
restored_model = tf.saved_model.load(model_dir)
predictions = restored_model.signatures["serving_default"](tf.constant(image))
logits = predictions["output_0"][0]
predicted_label = imagenet_int_to_str[int(np.argmax(logits))]
expected_label = "Indian_elephant, Elephas_maximus"
assert (
    predicted_label == expected_label
), f"Expected {expected_label} but was {predicted_label}"
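The cell above only checks the top-1 class. To inspect the model's confidence, the logits can be turned into probabilities with a softmax and sorted. A minimal plain-Python sketch, assuming the outputs are unnormalized logits as the variable name suggests; `top_k` is a hypothetical helper, and the four-class values and labels are made up:

```python
import math


def top_k(logits, labels, k=5):
    """Return the k most likely (label, probability) pairs via a softmax."""
    # Subtract the max logit before exponentiating for numerical stability.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    order = sorted(range(len(logits)), key=lambda i: probs[i], reverse=True)
    return [(labels[i], probs[i]) for i in order[:k]]


toy_logits = [1.0, 3.0, 0.5, 2.0]                                 # made-up values
toy_labels = ["tabby", "Indian_elephant", "goldfish", "tusker"]   # hypothetical subset
for label, prob in top_k(toy_logits, toy_labels, k=2):
    print(f"{label}: {prob:.3f}")
```

In the notebook, `top_k(logits.numpy().tolist(), imagenet_int_to_str)` would list the five most likely ImageNet-1k classes for the elephant image.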

## Inference with TensorFlow Hub 

The converted SavedModel can also be loaded as a TensorFlow Hub `KerasLayer`, as in the following snippet. See [this notebook](https://colab.research.google.com/github/sayakpaul/ViT-jax2tf/blob/main/classification.ipynb) for a complete example.

```python
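import numpy as np
import tensorflow as tf
import tensorflow_hub as hub

# Reuses `model_dir`, `image`, and `imagenet_int_to_str` from the cells above.
classification_model = tf.keras.Sequential([hub.KerasLayer(model_dir)])
predictions = classification_model.predict(image)
predicted_label = imagenet_int_to_str[int(np.argmax(predictions))]
print(predicted_label)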
```
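`np.argmax(predictions)` without an `axis` returns an index into the flattened array. That equals the class index here only because `predictions` has shape `(1, 1000)` for a single image; for a batch, the per-image class index requires `axis=-1`. A small illustration with toy values:

```python
import numpy as np

# Toy "predictions" for a batch of two images over three classes.
predictions = np.array([[0.1, 0.2, 0.3],
                        [0.9, 0.0, 0.1]])

flat_index = int(np.argmax(predictions))              # index into the flattened array
per_image = np.argmax(predictions, axis=-1).tolist()  # one class index per image
print(flat_index, per_image)  # 3 [2, 0]
```

Here the flat index 3 is not even a valid class index for three classes, which is why batched inference should use `axis=-1`.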