<a href="https://colab.research.google.com/github/sayakpaul/MLPMixer-jax2tf/blob/main/conversion.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Setup

In [None]:
!pip install -q absl-py>=0.12.0 chex>=0.0.7 clu>=0.0.3 einops>=0.3.0
!pip install -q flax==0.3.3 ml-collections==0.1.0 tf-nightly
!pip install -q numpy>=1.19.5 pandas>=1.1.0

In [None]:
# Clone repository and pull latest changes.
# The `[ -d ... ] ||` guard skips the clone when the directory already exists,
# so re-running this cell is safe. `--depth=1` fetches a shallow history of the
# `mixer-b32` branch only.
![ -d vision_transformer ] || git clone --depth=1 https://github.com/sayakpaul/vision_transformer -b mixer-b32
# The `cd` only applies to this one shell line; the notebook's working
# directory is unchanged afterwards.
!cd vision_transformer && git pull

## Imports

In [None]:
import sys

# Make the cloned repository importable so that `vit_jax` resolves below.
# The membership check keeps the path from being appended again on re-runs.
if "./vision_transformer" not in sys.path:
    sys.path.append("./vision_transformer")

from vit_jax import models
from vit_jax import checkpoint
from vit_jax.configs import common as common_config
from vit_jax.configs import models as models_config

from jax.experimental import jax2tf
import tensorflow as tf
import flax
import jax

from PIL import Image
from io import BytesIO
import numpy as np
import requests

In [None]:
# Record the library versions in the notebook output, since the conversion
# relies on a pinned flax release and a nightly TensorFlow build.
print(f"JAX version: {jax.__version__}")
print(f"FLAX version: {flax.__version__}")
print(f"TensorFlow version: {tf.__version__}")

## Select model

In [None]:
#@title Choose a model type
MIXER_MODELS = "B_32" #@param ["L_16", "B_16", "B_32"]
DATASET = "imagenet1k" #@param ["imagenet1k", "imagenet21k"]
SAM_PRETRAINED = True #@param {type:"boolean"}


def validate_model_selection(mixer_model, dataset, sam_pretrained):
    """Reject model/dataset/SAM combinations this notebook treats as unavailable.

    Args:
      mixer_model: one of the `MIXER_MODELS` form options, e.g. "B_16".
      dataset: one of the `DATASET` form options, "imagenet1k" or "imagenet21k".
      sam_pretrained: whether a SAM pre-trained checkpoint was requested.

    Raises:
      ValueError: if SAM is requested for "L_16" or for "imagenet21k", or if
        "B_32" is requested without SAM.
    """
    # The dataset literal must match the form option exactly ("imagenet21k",
    # no hyphen); otherwise SAM + ImageNet-21k would slip through unnoticed.
    if sam_pretrained and (mixer_model == "L_16" or dataset == "imagenet21k"):
        raise ValueError(f"{mixer_model} and {dataset} checkpoints are not available for SAM pre-training.")
    if not sam_pretrained and mixer_model == "B_32":
        raise ValueError(f"{mixer_model} is only available with SAM.")


validate_model_selection(MIXER_MODELS, DATASET, SAM_PRETRAINED)
print(f"Model type selected: Mixer-{MIXER_MODELS}")
print(f"Dataset selected: {DATASET}")

# Bucket prefix under which the checkpoint paths are built.
ROOT_GCS_PATH = "gs://mixer_models/"

In [None]:
# Set to False to export a headless feature extractor instead of a classifier.
classification_model = True
num_classes_map = {"imagenet1k": 1000, "imagenet21k": 21843}

if not classification_model:
    num_classes = None
    print("Will be converting a feature extraction model.")
else:
    # SAM checkpoints always get 1000 classes; otherwise the size of the
    # head follows the selected dataset.
    num_classes = 1000 if SAM_PRETRAINED else num_classes_map[DATASET]
    print(f"Will be converting a classification model with {num_classes} classes.")

## Instantiate model class and load checkpoints

In [None]:
# Pick the checkpoint location: SAM checkpoints sit under "sam/", all others
# under a folder named after the dataset.
path = (
    f"{ROOT_GCS_PATH}sam/Mixer-{MIXER_MODELS}.npz"
    if SAM_PRETRAINED
    else f"{ROOT_GCS_PATH}{DATASET}/Mixer-{MIXER_MODELS}.npz"
)

# Build the Mixer architecture for the selected variant.
model_config = models_config.MODEL_CONFIGS[f"Mixer-{MIXER_MODELS}"]
model = models.MlpMixer(num_classes=num_classes, **model_config)

params = checkpoint.load(path)

# A model without a classification head has no use for the "head" weights.
if not num_classes:
    _ = params.pop("head")

## Run conversion

Code has been reused from the official examples [here](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/README.md).

In [None]:
# Forward pass used for the conversion: the raw parameters are wrapped under
# the "params" key and the model is applied with train=False.
predict_fn = lambda params, inputs: model.apply(
    dict(params=params), inputs, train=False
)

# Gradients are requested only when `num_classes` is falsy, i.e. for the
# feature-extraction export (num_classes is None).
with_gradient = False if num_classes else True
tf_fn = jax2tf.convert(
    predict_fn,
    with_gradient=with_gradient,
    # One entry per predict_fn argument: nothing for `params`, and a symbolic
    # batch dimension "b" in front of the 224x224x3 image shape for `inputs`.
    polymorphic_shapes=[None, "b, 224, 224, 3"],
    enable_xla=True,
)

In [None]:
# Variables are trainable only when `num_classes` is falsy (the
# feature-extraction export); with a classification head they are frozen.
trainable = False if num_classes else True
# Mirror the nested parameter structure, turning every leaf into a tf.Variable.
param_vars = tf.nest.map_structure(
    lambda param: tf.Variable(param, trainable=trainable), params
)
# Bind the variables so that the resulting tf.function takes only the inputs.
tf_graph = tf.function(
    lambda inputs: tf_fn(param_vars, inputs), autograph=False, jit_compile=True
)

In [None]:
#@title SavedModel wrapper class utility from [here](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/saved_model_lib.py#L128)
class _ReusableSavedModelWrapper(tf.train.Checkpoint):
  """Wraps a function and its parameters for saving to a SavedModel.

  Implements the interface described at
  https://www.tensorflow.org/hub/reusable_saved_models.
  """

  def __init__(self, tf_graph, param_vars):
    """Stores the callable and its variables as attributes of this object.

    Args:
      tf_graph: a tf.function taking one argument (the inputs), which can
         be tuples/lists/dictionaries of np.ndarray or tensors. The function
         may have references to the tf.Variables in `param_vars`.
      param_vars: the parameters, as tuples/lists/dictionaries of tf.Variable,
         to be saved as the variables of the SavedModel.
    """
    super().__init__()
    # Implement the interface from https://www.tensorflow.org/hub/reusable_saved_models
    # `variables` is the flattened list of every parameter variable;
    # `trainable_variables` is the subset whose `trainable` flag is set.
    self.variables = tf.nest.flatten(param_vars)
    self.trainable_variables = [v for v in self.variables if v.trainable]
    # If you intend to prescribe regularization terms for users of the model,
    # add them as @tf.functions with no inputs to this list. Else drop this.
    self.regularization_losses = []
    # The wrapped tf.function is stored as the instance's `__call__` attribute.
    self.__call__ = tf_graph


In [None]:
# Serving input: float32 images of 224x224x3 with a free batch dimension.
input_signatures = [tf.TensorSpec(shape=[None, 224, 224, 3], dtype=tf.float32)]

# Output directory: "<variant>_sam" or "<variant>_<dataset>", with an extra
# "_fe" suffix when no classification head is exported.
model_dir = MIXER_MODELS + ("_sam" if SAM_PRETRAINED else f"_{DATASET}")
if not num_classes:
    model_dir = f"{model_dir}_fe"

signatures = {}
saved_model_options = None

print(f"Saving model to {model_dir} directory.")

In [None]:
# Trace the graph for the declared input spec and register it under the
# default serving signature key.
serving_fn = tf_graph.get_concrete_function(input_signatures[0])
signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = serving_fn

wrapper = _ReusableSavedModelWrapper(tf_graph, param_vars)

# Exports converted with gradients also need the custom-gradients save option.
if with_gradient:
    if saved_model_options:
        saved_model_options.experimental_custom_gradients = True
    else:
        saved_model_options = tf.saved_model.SaveOptions(
            experimental_custom_gradients=True
        )

tf.saved_model.save(
    wrapper, model_dir, signatures=signatures, options=saved_model_options
)

## Functional test (credits: [Willi Gierke](https://ch.linkedin.com/in/willi-gierke))

***Currently only applicable to classification models with the 1000 ImageNet-1k classes.***

### Image preprocessing utilities 

In [None]:
def preprocess_image(image):
    """Resize an image to 224x224 and rescale it for the model.

    Args:
      image: anything `np.array` can convert to an image array.

    Returns:
      A NumPy array with a leading batch axis of size 1, whose values are
      `(pixel - 127.5) / 127.5` (presumably [-1, 1] for 8-bit input).
    """
    pixels = np.array(image)
    resized = tf.cast(tf.image.resize(pixels, (224, 224)), tf.float32)
    scaled = (resized - 127.5) / 127.5
    batched = tf.expand_dims(scaled, 0)
    return batched.numpy()

def load_image_from_url(url):
    """Download an image and return it preprocessed for the model.

    Args:
      url: HTTP(S) address of the image file.

    Returns:
      The result of `preprocess_image` for the decoded image.

    Raises:
      requests.HTTPError: if the server answers with an error status.
      requests.Timeout: if the server does not respond within 30 seconds.
    """
    # Fail on HTTP errors instead of handing an error page to the image
    # decoder; the timeout keeps a stalled download from hanging the notebook.
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    # Force three channels: the exported model takes (224, 224, 3) inputs,
    # but a decoded file may be grayscale, palettised or RGBA.
    image = Image.open(BytesIO(response.content)).convert("RGB")
    return preprocess_image(image)

# Download the ImageNet-1k label file; `-O` fixes the local filename that the
# label-loading cell opens.
!wget https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt -O ilsvrc2012_wordnet_lemmas.txt

### Load image and ImageNet-1k class mappings

In [None]:
# Class index -> label: entry i is line i of the mapping file, stripped of
# trailing whitespace.
with open("ilsvrc2012_wordnet_lemmas.txt", "r") as labels_file:
    lines = labels_file.readlines()
imagenet_int_to_str = list(map(str.rstrip, lines))

# Test picture for the functional check below.
img_url = "https://p0.pikrepo.com/preview/853/907/close-up-photo-of-gray-elephant.jpg"
image = load_image_from_url(img_url)

### Inference

In [None]:
# Load the converted SavedModel and check whether it finds the elephant.
restored_model = tf.saved_model.load(model_dir)
predictions = restored_model.signatures["serving_default"](tf.constant(image))
logits = predictions["output_0"][0]
predicted_label = imagenet_int_to_str[int(np.argmax(logits))]
expected_label = "Indian_elephant, Elephas_maximus"
assert (
    predicted_label == expected_label
), f"Expected {expected_label} but was {predicted_label}"