## Imports

In [None]:
import tensorflow as tf 
import pandas as pd
import numpy as np

from pprint import pformat

In [None]:
import sys

sys.path.append("..")

from vit.configs import base_config
from vit.layers import mha
from vit.models import ViTClassifierExtended

## Select the master dataframe from [AugReg paper](https://arxiv.org/abs/2106.10270)

In [None]:
# Read the AugReg checkpoint index CSV directly from the GCS bucket.
index_csv_uri = "gs://vit_models/augreg/index.csv"
with tf.io.gfile.GFile(index_csv_uri) as index_file:
    df = pd.read_csv(index_file)

df.head()

## Pick a checkpoint

**Criteria**

* B16 architecture
* Resolution 224
* Patch size 16
* Best top-1 accuracy on ImageNet-1k

In [None]:
# Keep the B/16 rows with ds == "i21k" that were adapted to imagenet2012 at
# resolution 224, ordered by `adapt_final_test` from highest to lowest.
b16_filter = (
    'ds=="i21k" & adapt_resolution==224 & adapt_ds=="imagenet2012" & name=="B/16"'
)
b16s = df.query(b16_filter).sort_values("adapt_final_test", ascending=False)
b16s.head()

In [None]:
# `b16s` is sorted best-first, so its first row is the checkpoint to convert.
top_row = b16s.iloc[0]
best_b16_i1k_checkpoint = str(top_row["adapt_filename"])
top_row["adapt_filename"], top_row["adapt_final_test"]

In [None]:
filename = best_b16_i1k_checkpoint
path = f"gs://vit_models/augreg/{filename}.npz"

# Report the size of the remote checkpoint file alongside its location.
size_mib = tf.io.gfile.stat(path).length / 1024 / 1024
print(f"{size_mib:.1f} MiB - {path}")

## Copy over the checkpoint and load it

In [None]:
# Copy the checkpoint into the working directory unless it is already there,
# so a fresh "Restart & Run All" does not depend on a manual `gsutil cp`.
local_path = path.rsplit("/", 1)[-1]  # file name = last "/"-separated segment
if not tf.io.gfile.exists(local_path):
    tf.io.gfile.copy(path, local_path)
local_path

In [None]:
# Materialise every array of the .npz archive into a plain dict so the
# parameters remain usable once the archive has been closed.
with np.load(local_path) as npz_archive:
    params_jax = {key: npz_archive[key] for key in npz_archive.keys()}

# print(pformat(list(params_jax.keys())))

In [None]:
# Inspect the stored shape of one attention key kernel from the checkpoint.
params_jax["Transformer/encoderblock_0/MultiHeadDotProductAttention_1/key/kernel"].shape

## Instantiate a ViT model in TF

In [None]:
# Start from the base configuration shipped with the `vit` package.
config = base_config.get_config()
# The config is edited inside its `unlocked()` context; `num_classes` is set
# to 1000 here.
with config.unlocked():
    config.num_classes = 1000

# Show the resulting configuration as a dict.
config.to_dict()

In [None]:
# Make sure it works: build the model, call it once on a (1, 224, 224, 3)
# tensor of ones and display the shape of the first returned element.
vit_b16_model = ViTClassifierExtended(config)
vit_b16_model(tf.ones((1, 224, 224, 3)))[0].shape

## Copy the projection layer params

In [None]:
# Projection.
# `assign` accepts array-like values, so the checkpoint arrays are passed
# directly rather than wrapped in a throwaway `tf.Variable`.
projection_layer = vit_b16_model.layers[0].layers[0]
projection_layer.kernel.assign(params_jax["embedding/kernel"])
# The trailing `;` keeps the notebook from echoing the value returned by `assign`.
projection_layer.bias.assign(params_jax["embedding/bias"]);

In [None]:
# The projection kernel held by the TF model must match the checkpoint value.
tf_projection_kernel = vit_b16_model.layers[0].layers[0].kernel.numpy()
np.testing.assert_allclose(tf_projection_kernel, params_jax["embedding/kernel"])

In [None]:
# The projection bias held by the TF model must match the checkpoint value.
tf_projection_bias = vit_b16_model.layers[0].layers[0].bias.numpy()
np.testing.assert_allclose(tf_projection_bias, params_jax["embedding/bias"])

## Copy the positional embeddings

In [None]:
# Positional embedding.
# The checkpoint array goes straight into `assign`; no temporary
# `tf.Variable` is needed. The trailing `;` hides the returned value.
vit_b16_model.positional_embedding.assign(
    params_jax["Transformer/posembed_input/pos_embedding"]
);

In [None]:
# The positional embedding of the TF model must match the checkpoint value.
tf_pos_embedding = vit_b16_model.positional_embedding.numpy()
jax_pos_embedding = params_jax["Transformer/posembed_input/pos_embedding"]
np.testing.assert_allclose(tf_pos_embedding, jax_pos_embedding)

## Copy the `cls_token`

In [None]:
# Cls token.
# Pass the array directly to `assign`; the trailing `;` hides the returned value.
vit_b16_model.cls_token.assign(params_jax["cls"]);

In [None]:
# The class token of the TF model must match the checkpoint value.
tf_cls_token = vit_b16_model.cls_token.numpy()
np.testing.assert_allclose(tf_cls_token, params_jax["cls"])

## Copy the final Layer Norm params

In [None]:
# Final layer norm layer (second-to-last layer of the model).
# Arrays are handed to `assign` directly instead of via a temporary
# `tf.Variable`; the trailing `;` hides the returned value.
final_layer_norm = vit_b16_model.layers[-2]
final_layer_norm.gamma.assign(params_jax["Transformer/encoder_norm/scale"])
final_layer_norm.beta.assign(params_jax["Transformer/encoder_norm/bias"]);

In [None]:
# The final layer norm's gamma must match the checkpoint's scale.
tf_final_gamma = vit_b16_model.layers[-2].gamma.numpy()
np.testing.assert_allclose(tf_final_gamma, params_jax["Transformer/encoder_norm/scale"])

In [None]:
# The final layer norm's beta must match the checkpoint's bias.
tf_final_beta = vit_b16_model.layers[-2].beta.numpy()
np.testing.assert_allclose(tf_final_beta, params_jax["Transformer/encoder_norm/bias"])

## Copy head layer params

In [None]:
# Head layer (last layer of the model).
# Arrays are handed to `assign` directly instead of via a temporary
# `tf.Variable`; the trailing `;` hides the returned value.
head_layer = vit_b16_model.layers[-1]
head_layer.kernel.assign(params_jax["head/kernel"])
head_layer.bias.assign(params_jax["head/bias"]);

In [None]:
# The head kernel of the TF model must match the checkpoint value.
tf_head_kernel = vit_b16_model.layers[-1].kernel.numpy()
np.testing.assert_allclose(tf_head_kernel, params_jax["head/kernel"])

In [None]:
# The head bias of the TF model must match the checkpoint value.
tf_head_bias = vit_b16_model.layers[-1].bias.numpy()
np.testing.assert_allclose(tf_head_bias, params_jax["head/bias"])

## Copy the Transformer params

**Structure of a single Transformer encoder block in the JAX model**:


```md
 'Transformer/encoderblock_0/LayerNorm_0/bias',
 'Transformer/encoderblock_0/LayerNorm_0/scale',
 'Transformer/encoderblock_0/LayerNorm_2/bias',
 'Transformer/encoderblock_0/LayerNorm_2/scale',
 'Transformer/encoderblock_0/MlpBlock_3/Dense_0/bias',
 'Transformer/encoderblock_0/MlpBlock_3/Dense_0/kernel',
 'Transformer/encoderblock_0/MlpBlock_3/Dense_1/bias',
 'Transformer/encoderblock_0/MlpBlock_3/Dense_1/kernel',
 'Transformer/encoderblock_0/MultiHeadDotProductAttention_1/key/bias',
 'Transformer/encoderblock_0/MultiHeadDotProductAttention_1/key/kernel',
 'Transformer/encoderblock_0/MultiHeadDotProductAttention_1/out/bias',
 'Transformer/encoderblock_0/MultiHeadDotProductAttention_1/out/kernel',
 'Transformer/encoderblock_0/MultiHeadDotProductAttention_1/query/bias',
 'Transformer/encoderblock_0/MultiHeadDotProductAttention_1/query/kernel',
 'Transformer/encoderblock_0/MultiHeadDotProductAttention_1/value/bias',
 'Transformer/encoderblock_0/MultiHeadDotProductAttention_1/value/kernel',
```

In [None]:
def modify_attention_block(tf_component, jax_component, params_jax, config):
    """Copy one JAX attention projection (kernel and bias) into a TF layer.

    Both arrays are looked up and reshaped before anything is assigned, so a
    missing key or a failed reshape leaves `tf_component` untouched.

    Args:
        tf_component: Object exposing `kernel` and `bias` attributes, each
            providing an `assign` method; it is updated in place.
        jax_component: Key prefix in `params_jax`; the arrays are read from
            `"{jax_component}/kernel"` and `"{jax_component}/bias"`.
        params_jax: Mapping from parameter name to NumPy array.
        config: Object whose `projection_dim` is the number of rows of the
            reshaped kernel.

    Returns:
        The same `tf_component` object that was passed in.

    Raises:
        KeyError: If either expected key is missing from `params_jax`.
        ValueError: If the kernel cannot be reshaped to
            `(config.projection_dim, -1)`.
    """
    # Reshape the kernel to 2-D with `projection_dim` rows; flatten the bias.
    kernel = params_jax[f"{jax_component}/kernel"].reshape(config.projection_dim, -1)
    bias = params_jax[f"{jax_component}/bias"].reshape(-1)

    # `assign` accepts array-like values, so no temporary `tf.Variable` is built.
    tf_component.kernel.assign(kernel)
    tf_component.bias.assign(bias)
    return tf_component

In [None]:
# Encoder blocks are the nested Keras models other than the "projection" one.
# Their order in `vit_b16_model.layers` is assumed to line up with
# encoderblock_0, encoderblock_1, ... in the checkpoint.
encoder_blocks = [
    candidate
    for candidate in vit_b16_model.layers
    if isinstance(candidate, tf.keras.Model) and candidate.name != "projection"
]
# Fail loudly rather than silently copying nothing.
assert encoder_blocks, "No Transformer encoder blocks found to copy into"

for idx, tf_block in enumerate(encoder_blocks):
    block_prefix = f"Transformer/encoderblock_{idx}"
    attn_layer_jax_prefix = f"{block_prefix}/MultiHeadDotProductAttention_1"

    # One pass over the block is enough: the layer-norm and dense counters are
    # independent of each other and of the attention layer.
    layer_norm_idx = 0
    ffn_layer_idx = 0
    for layer in tf_block.layers:
        # LayerNorm layers.
        if isinstance(layer, tf.keras.layers.LayerNormalization):
            layer_norm_jax_prefix = f"{block_prefix}/LayerNorm_{layer_norm_idx}"
            layer.gamma.assign(params_jax[f"{layer_norm_jax_prefix}/scale"])
            layer.beta.assign(params_jax[f"{layer_norm_jax_prefix}/bias"])
            # The checkpoint names the block's norms LayerNorm_0 and LayerNorm_2.
            layer_norm_idx += 2

        # FFN layers.
        if isinstance(layer, tf.keras.layers.Dense):
            dense_layer_jax_prefix = f"{block_prefix}/MlpBlock_3/Dense_{ffn_layer_idx}"
            layer.kernel.assign(params_jax[f"{dense_layer_jax_prefix}/kernel"])
            layer.bias.assign(params_jax[f"{dense_layer_jax_prefix}/bias"])
            ffn_layer_idx += 1

        # Attention layer.
        if isinstance(layer, mha.TFViTAttention):
            # Key, query and value share one copy routine; it updates the
            # sub-layer in place, so the attribute needs no reassignment.
            for projection_name in ("key", "query", "value"):
                modify_attention_block(
                    getattr(layer.self_attention, projection_name),
                    f"{attn_layer_jax_prefix}/{projection_name}",
                    params_jax,
                    config,
                )
            # Final dense projection: reshaped so the last axis is projection_dim.
            layer.dense_output.dense.kernel.assign(
                params_jax[f"{attn_layer_jax_prefix}/out/kernel"].reshape(
                    -1, config.projection_dim
                )
            )
            layer.dense_output.dense.bias.assign(
                params_jax[f"{attn_layer_jax_prefix}/out/bias"]
            )

In [None]:
# Encoder blocks are the nested Keras models other than the "projection" one.
# Their order in `vit_b16_model.layers` is assumed to line up with
# encoderblock_0, encoderblock_1, ... in the checkpoint.
encoder_blocks = [
    candidate
    for candidate in vit_b16_model.layers
    if isinstance(candidate, tf.keras.Model) and candidate.name != "projection"
]
# Guard against a vacuous pass: with no blocks matched, no check below runs.
assert encoder_blocks, "No Transformer encoder blocks found to verify"

for idx, tf_block in enumerate(encoder_blocks):
    block_prefix = f"Transformer/encoderblock_{idx}"
    attn_layer_jax_prefix = f"{block_prefix}/MultiHeadDotProductAttention_1"

    layer_norm_idx = 0
    ffn_layer_idx = 0
    for layer in tf_block.layers:
        # Layer norm.
        if isinstance(layer, tf.keras.layers.LayerNormalization):
            layer_norm_jax_prefix = f"{block_prefix}/LayerNorm_{layer_norm_idx}"
            np.testing.assert_allclose(
                layer.gamma.numpy(), params_jax[f"{layer_norm_jax_prefix}/scale"]
            )
            np.testing.assert_allclose(
                layer.beta.numpy(), params_jax[f"{layer_norm_jax_prefix}/bias"]
            )
            # The checkpoint names the block's norms LayerNorm_0 and LayerNorm_2.
            layer_norm_idx += 2

        # FFN layers.
        if isinstance(layer, tf.keras.layers.Dense):
            dense_layer_jax_prefix = f"{block_prefix}/MlpBlock_3/Dense_{ffn_layer_idx}"
            np.testing.assert_allclose(
                layer.kernel.numpy(), params_jax[f"{dense_layer_jax_prefix}/kernel"]
            )
            np.testing.assert_allclose(
                layer.bias.numpy(), params_jax[f"{dense_layer_jax_prefix}/bias"]
            )
            ffn_layer_idx += 1

        # Attention layers.
        if isinstance(layer, mha.TFViTAttention):
            # Key, query and value are compared against the checkpoint arrays
            # after the same reshapes that were applied when copying them.
            for projection_name in ("key", "query", "value"):
                tf_projection = getattr(layer.self_attention, projection_name)
                jax_prefix = f"{attn_layer_jax_prefix}/{projection_name}"
                np.testing.assert_allclose(
                    tf_projection.kernel.numpy(),
                    params_jax[f"{jax_prefix}/kernel"].reshape(
                        config.projection_dim, -1
                    ),
                )
                np.testing.assert_allclose(
                    tf_projection.bias.numpy(),
                    params_jax[f"{jax_prefix}/bias"].reshape(-1),
                )

            # Final dense projection; both sides are plain NumPy arrays.
            np.testing.assert_allclose(
                layer.dense_output.dense.kernel.numpy(),
                params_jax[f"{attn_layer_jax_prefix}/out/kernel"].reshape(
                    -1, config.projection_dim
                ),
            )
            np.testing.assert_allclose(
                layer.dense_output.dense.bias.numpy(),
                params_jax[f"{attn_layer_jax_prefix}/out/bias"],
            )

## Verification

Largely taken from here: https://github.com/sayakpaul/BiT-jax2tf/blob/main/convert_jax_weights_tf.ipynb.

In [None]:
import requests
from PIL import Image
from io import BytesIO

In [None]:
def preprocess_image(image):
    """Turn an image into a normalised batch of one.

    The input is converted to a NumPy array, resized to 224x224, cast to
    float32 and rescaled with `(x - 127.5) / 127.5`.

    Args:
        image: Image data accepted by `np.array`.

    Returns:
        A NumPy array with a leading batch axis of size 1.
    """
    pixels = np.array(image)
    resized = tf.cast(tf.image.resize(pixels, (224, 224)), tf.float32)
    normalized = (resized - 127.5) / 127.5
    batched = tf.expand_dims(normalized, 0)
    return batched.numpy()

def load_image_from_url(url):
    """Download an image and return it preprocessed for the model.

    Args:
        url: HTTP(S) address of the image.

    Returns:
        The result of `preprocess_image` for the downloaded image.

    Raises:
        requests.HTTPError: If the server responds with an error status.
        requests.Timeout: If the server does not respond within 30 seconds.
    """
    response = requests.get(url, timeout=30)
    # Surface the HTTP status instead of an opaque image-decoding error.
    response.raise_for_status()
    # Force three channels so grayscale or RGBA files still match the
    # (224, 224, 3) input the model was called with earlier in the notebook.
    image = Image.open(BytesIO(response.content)).convert("RGB")
    return preprocess_image(image)

# Download the label file that is read below to map class indices to names.
!wget https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt -O ilsvrc2012_wordnet_lemmas.txt

In [None]:
# One label per line; a label's position in the list is its class index.
with open("ilsvrc2012_wordnet_lemmas.txt", "r") as labels_file:
    lines = list(labels_file)
imagenet_int_to_str = list(map(str.rstrip, lines))

# Fetch and preprocess the image used for the sanity check.
img_url = "https://p0.pikrepo.com/preview/853/907/close-up-photo-of-gray-elephant.jpg"
image = load_image_from_url(img_url)

In [None]:
# Classify the test image and compare the top-scoring label with the expected one.
predictions = vit_b16_model.predict(image)
logits = predictions[0]
top_class_idx = int(np.argmax(logits))
predicted_label = imagenet_int_to_str[top_class_idx]
expected_label = "Indian_elephant, Elephas_maximus"
labels_match = predicted_label == expected_label
assert labels_match, f"Expected {expected_label} but was {predicted_label}"

In [None]:
# Persist the converted model under the "vit_b16_patch16_224" path.
vit_b16_model.save("vit_b16_patch16_224")