In [1]:
import timm
import torch
import numpy as np
import tensorflow as tf

from copy import deepcopy
from ml_collections import ConfigDict

from typing import Dict

In [2]:
import sys

sys.path.append("..")

from vit.model_configs import base_config
from vit.layers import mha
from vit.deit_models import ViTDistilled

In [3]:
# Pretrained PyTorch reference model: the source of every weight ported below.
deit_tiny_distilled_patch16_224 = timm.create_model(
    "deit_tiny_distilled_patch16_224", pretrained=True, num_classes=1000
)

In [4]:
# Sanity check: the distilled checkpoint carries a separate distillation token.
"dist_token" in deit_tiny_distilled_patch16_224.state_dict().keys()

True

In [5]:
distilled_tiny_tf_config = base_config.get_config(name="deit_tiny_distilled_patch16_224")
deit_tiny_distilled_patch16_224_tf = ViTDistilled(distilled_tiny_tf_config)

# Run one dummy batch through the subclassed Keras model so that it is built
# (its variables exist) before weights are assigned in the cells below.
dummy_inputs = tf.ones((2, 224, 224, 3))
tf_outputs = deit_tiny_distilled_patch16_224_tf(dummy_inputs)
tf_outputs[0].shape

2022-03-27 09:14:15.838939: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


TensorShape([2, 1000])

In [6]:
# The TF port must have exactly as many parameters as the PyTorch model;
# report both counts so a mismatch is immediately actionable.
tf_param_count = deit_tiny_distilled_patch16_224_tf.count_params()
pt_param_count = sum(p.numel() for p in deit_tiny_distilled_patch16_224.parameters())
assert tf_param_count == pt_param_count, (
    f"Parameter count mismatch: TF={tf_param_count:,} vs PyTorch={pt_param_count:,}"
)

In [7]:
# PyTorch state dict converted to NumPy arrays, keyed by parameter name.
deit_tiny_distilled_patch16_224_dict = {
    param_name: tensor.numpy()
    for param_name, tensor in deit_tiny_distilled_patch16_224.state_dict().items()
}

In [8]:
def conv_transpose(w: np.ndarray):
    """Permute a PyTorch Conv2d kernel into the Keras Conv2D kernel layout.

    The axis permutation ``(2, 3, 1, 0)`` maps PyTorch's
    ``(out_channels, in_channels, kH, kW)`` to Keras'
    ``(kH, kW, in_channels, out_channels)``.
    """
    keras_axes = (2, 3, 1, 0)
    return np.transpose(w, axes=keras_axes)


def modify_attention_block(qkv: np.ndarray, config: "ConfigDict"):
    """Split a fused PyTorch ``qkv`` parameter into query, key and value parts.

    The fused dimension (rows of the PyTorch weight, entries of the bias) is
    laid out as ``[q | k | v]``, each ``config.projection_dim`` wide.

    Args:
        qkv: Either the 2-D weight in PyTorch ``(out_features, in_features)``
            layout, or the 1-D bias of length ``out_features``.
        config: Model config; only ``config.projection_dim`` is read.

    Returns:
        A ``(q, k, v)`` tuple of fresh copies (never views of ``qkv``). For a
        2-D input each part is transposed to the Keras Dense kernel layout
        ``(in_features, projection_dim)``; for a 1-D input each part is a
        ``(projection_dim,)`` vector.

    Raises:
        ValueError: If ``qkv`` is not 1-D or 2-D, or if its fused dimension is
            not exactly ``3 * config.projection_dim``.
    """
    if qkv.ndim == 2:
        # PyTorch Linear stores (out, in); Keras Dense expects (in, out).
        qkv_tf = qkv.T
    elif qkv.ndim == 1:
        qkv_tf = qkv
    else:
        raise ValueError("NumPy arrays with either two or one dimension are allowed.")

    proj_dim = config.projection_dim
    fused_dim = qkv_tf.shape[-1]
    if fused_dim != 3 * proj_dim:
        # Without this check a mismatched config silently yields overlapping
        # or truncated q/k/v slices (v was taken from the end of the array).
        raise ValueError(
            f"Expected the fused qkv dimension to be 3 * projection_dim = "
            f"{3 * proj_dim}, got {fused_dim}."
        )

    q = qkv_tf[..., :proj_dim].copy()
    k = qkv_tf[..., proj_dim : 2 * proj_dim].copy()
    v = qkv_tf[..., 2 * proj_dim :].copy()
    return q, k, v


def get_tf_qkv(pt_component: str, pt_params: Dict[str, np.ndarray], config: ConfigDict):
    """Look up a fused PyTorch qkv weight/bias pair and split both into q, k, v.

    Args:
        pt_component: State-dict prefix of the attention module
            (e.g. ``"blocks.0.attn"``).
        pt_params: PyTorch state dict with values converted to NumPy arrays.
        config: Forwarded unchanged to ``modify_attention_block``.

    Returns:
        ``((q_w, k_w, v_w), (q_b, k_b, v_b))``.
    """
    fused_weight = pt_params[f"{pt_component}.qkv.weight"]
    fused_bias = pt_params[f"{pt_component}.qkv.bias"]

    split_weight = modify_attention_block(fused_weight, config)
    split_bias = modify_attention_block(fused_bias, config)
    return tuple(split_weight), tuple(split_bias)


def modify_tf_block(
    tf_component: tf.keras.layers.Layer,
    pt_weight: np.ndarray,
    pt_bias: np.ndarray,
    config: ConfigDict,
    is_attn: bool = False,
):
    """Copy a PyTorch weight/bias pair into a Keras layer's variables in place.

    Conv2D kernels are permuted with ``conv_transpose``. Dense kernels are
    transposed from PyTorch ``(out, in)`` to Keras ``(in, out)`` unless
    ``is_attn`` is set, because the q/k/v kernels coming from
    ``modify_attention_block`` are already transposed.

    Args:
        tf_component: Layer exposing ``kernel`` and ``bias`` variables.
        pt_weight: PyTorch weight array.
        pt_bias: PyTorch bias array; assigned without any transformation.
        config: Not used by this function; kept for call-site compatibility.
        is_attn: Skip the Dense transpose for pre-transposed attention kernels.

    Returns:
        ``tf_component`` itself. Its variables are updated in place, so callers
        do not need to store the return value anywhere.
    """
    if isinstance(tf_component, tf.keras.layers.Conv2D):
        pt_weight = conv_transpose(pt_weight)
    if isinstance(tf_component, tf.keras.layers.Dense) and not is_attn:
        pt_weight = pt_weight.transpose()

    # `Variable.assign` accepts NumPy arrays directly and still raises on a
    # shape mismatch; wrapping them in `tf.Variable` first only allocated
    # throwaway variables.
    tf_component.kernel.assign(pt_weight)
    tf_component.bias.assign(pt_bias)
    return tf_component

In [9]:
# Projection: the first layer of the TF `projection` Sequential is expected to
# be the Conv2D patch embedding (PyTorch `patch_embed.proj`). `modify_tf_block`
# updates the layer in place; assigning its result back into `model.layers`
# would only write into a temporary list, so the result is discarded with `;`.

modify_tf_block(
    deit_tiny_distilled_patch16_224_tf.layers[0].layers[0],
    deit_tiny_distilled_patch16_224_dict["patch_embed.proj.weight"],
    deit_tiny_distilled_patch16_224_dict["patch_embed.proj.bias"],
    distilled_tiny_tf_config,
);

In [10]:
# Positional embedding shape: (1, tokens, dim); here 196 patches + cls + dist tokens.
np.shape(deit_tiny_distilled_patch16_224_dict["pos_embed"])

(1, 198, 192)

In [11]:
# Learned positional embeddings. `assign` takes the NumPy array directly;
# the trailing `;` suppresses the echoed variable instead of a `print(" ")`.
deit_tiny_distilled_patch16_224_tf.positional_embedding.assign(
    deit_tiny_distilled_patch16_224_dict["pos_embed"]
);

 


In [12]:
# Class and distillation tokens. `assign` takes the NumPy arrays directly;
# the trailing `;` suppresses the echoed variable instead of a `print(" ")`.
deit_tiny_distilled_patch16_224_tf.cls_token.assign(
    deit_tiny_distilled_patch16_224_dict["cls_token"]
)
deit_tiny_distilled_patch16_224_tf.dist_token.assign(
    deit_tiny_distilled_patch16_224_dict["dist_token"]
);

 


In [13]:
# Print the Keras layer summary (layer names, output shapes, parameter counts)
# of the TF port.
deit_tiny_distilled_patch16_224_tf.summary()

Model: "vi_t_distilled"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 projection (Sequential)     (2, 196, 192)             147648    
                                                                 
 transformer_block_0 (Functi  [(None, 198, 192),       444864    
 onal)                        (None, 3, None, None)]             
                                                                 
 transformer_block_1 (Functi  [(None, 198, 192),       444864    
 onal)                        (None, 3, None, None)]             
                                                                 
 transformer_block_2 (Functi  [(None, 198, 192),       444864    
 onal)                        (None, 3, None, None)]             
                                                                 
 transformer_block_3 (Functi  [(None, 198, 192),       444864    
 onal)                        (None, 3, None, None)]

In [14]:
# Final layer norm layer (PyTorch `norm`). It is addressed by position: the
# heads occupy layers[-2] and layers[-1], so layers[-3] should be the final
# LayerNorm. Fail loudly if the model layout ever changes.
final_norm = deit_tiny_distilled_patch16_224_tf.layers[-3]
assert isinstance(final_norm, tf.keras.layers.LayerNormalization), (
    f"Expected layers[-3] to be LayerNormalization, got {type(final_norm).__name__}"
)
final_norm.gamma.assign(deit_tiny_distilled_patch16_224_dict["norm.weight"])
final_norm.beta.assign(deit_tiny_distilled_patch16_224_dict["norm.bias"]);

 


In [15]:
# Display the PyTorch module tree. Its attribute paths (patch_embed.proj,
# blocks.N.norm1/norm2, blocks.N.attn.qkv/proj, blocks.N.mlp.fc1/fc2) are the
# state-dict key prefixes used by the weight-porting cells.
deit_tiny_distilled_patch16_224

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 192, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=192, out_features=576, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=192, out_features=192, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=192, out_features=768, bias=True)
        (act): GELU()
        (drop1): Dropout(p=0.0, inplace=False)
        (fc2): Linear(in_features=768, out_features=192, bias=True)
        (drop2): Dropout(p=0.0, inplace=False)
      )
    )
    (1): Block(
      (norm1): LayerNorm((192,), ep

In [16]:
# Head layers, looked up by name. `modify_tf_block` updates each Dense layer in
# place; the old `model.layers[-2] = ...` assignments only wrote into a
# temporary list and had no effect.
for tf_head_name, pt_head_prefix in (
    ("classification_head", "head"),
    ("distillation_head", "head_dist"),
):
    modify_tf_block(
        deit_tiny_distilled_patch16_224_tf.get_layer(tf_head_name),
        deit_tiny_distilled_patch16_224_dict[f"{pt_head_prefix}.weight"],
        deit_tiny_distilled_patch16_224_dict[f"{pt_head_prefix}.bias"],
        distilled_tiny_tf_config,
    )

In [17]:
# Transformer blocks. TF blocks are the nested `tf.keras.Model` layers other
# than `projection`; the i-th one (in model order) receives PyTorch `blocks.{i}`.
transformer_blocks = [
    outer_layer
    for outer_layer in deit_tiny_distilled_patch16_224_tf.layers
    if isinstance(outer_layer, tf.keras.Model) and outer_layer.name != "projection"
]

# Every PyTorch block must be consumed, otherwise weights would be silently skipped.
pt_block_indices = {
    int(param_name.split(".")[1])
    for param_name in deit_tiny_distilled_patch16_224_dict
    if param_name.startswith("blocks.")
}
assert len(transformer_blocks) == len(pt_block_indices), (
    f"{len(transformer_blocks)} TF blocks vs {len(pt_block_indices)} PyTorch blocks"
)

for idx, tf_block in enumerate(transformer_blocks):
    pt_block_name = f"blocks.{idx}"

    # LayerNorm layers: the n-th LayerNormalization maps to `norm{n}` (1-based).
    layer_norms = [
        layer
        for layer in tf_block.layers
        if isinstance(layer, tf.keras.layers.LayerNormalization)
    ]
    for layer_norm_idx, layer in enumerate(layer_norms, start=1):
        layer_norm_pt_prefix = f"{pt_block_name}.norm{layer_norm_idx}"
        layer.gamma.assign(
            deit_tiny_distilled_patch16_224_dict[f"{layer_norm_pt_prefix}.weight"]
        )
        layer.beta.assign(
            deit_tiny_distilled_patch16_224_dict[f"{layer_norm_pt_prefix}.bias"]
        )

    # FFN layers: the n-th top-level Dense maps to `mlp.fc{n}` (1-based).
    dense_layers = [
        layer for layer in tf_block.layers if isinstance(layer, tf.keras.layers.Dense)
    ]
    for ffn_layer_idx, layer in enumerate(dense_layers, start=1):
        dense_layer_pt_prefix = f"{pt_block_name}.mlp.fc{ffn_layer_idx}"
        modify_tf_block(
            layer,
            deit_tiny_distilled_patch16_224_dict[f"{dense_layer_pt_prefix}.weight"],
            deit_tiny_distilled_patch16_224_dict[f"{dense_layer_pt_prefix}.bias"],
            distilled_tiny_tf_config,
        )

    # Attention layer. The fused qkv split depends only on the block, so do it
    # once here instead of once per sub-layer.
    (q_w, k_w, v_w), (q_b, k_b, v_b) = get_tf_qkv(
        f"{pt_block_name}.attn",
        deit_tiny_distilled_patch16_224_dict,
        distilled_tiny_tf_config,
    )
    attention_layers = [
        layer for layer in tf_block.layers if isinstance(layer, mha.TFViTAttention)
    ]
    assert attention_layers, f"No TFViTAttention layer found in {tf_block.name}"

    for layer in attention_layers:
        self_attention = layer.self_attention
        # q/k/v kernels are already in Keras layout, hence is_attn=True.
        for tf_dense, weight, bias in (
            (self_attention.key, k_w, k_b),
            (self_attention.query, q_w, q_b),
            (self_attention.value, v_w, v_b),
        ):
            modify_tf_block(
                tf_dense, weight, bias, distilled_tiny_tf_config, is_attn=True
            )
        # Final dense projection: a regular PyTorch Linear, so it is transposed.
        modify_tf_block(
            layer.dense_output.dense,
            deit_tiny_distilled_patch16_224_dict[f"{pt_block_name}.attn.proj.weight"],
            deit_tiny_distilled_patch16_224_dict[f"{pt_block_name}.attn.proj.bias"],
            distilled_tiny_tf_config,
        )

In [18]:
import requests
from PIL import Image
from io import BytesIO

In [19]:
# ImageNet RGB mean/std, scaled to the 0-255 pixel range because images are
# not rescaled to [0, 1] before normalization.
norm_layer = tf.keras.layers.Normalization(
    mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
    variance=[(0.229 * 255) ** 2, (0.224 * 255) ** 2, (0.225 * 255) ** 2],
)


def preprocess_image(image):
    """Turn a PIL image (or HxWx3 array) into a normalized batch of one.

    Returns a ``(1, 224, 224, 3)`` tensor normalized with ``norm_layer``.

    NOTE(review): this resizes straight to 224x224 (``tf.image.resize`` defaults
    to bilinear). timm's DeiT eval transform presumably resizes and then
    center-crops with bicubic interpolation; confirm if exact parity with the
    PyTorch pipeline is needed.
    """
    if isinstance(image, Image.Image):
        # RGBA PNGs, grayscale, palette and CMYK images would otherwise yield
        # 4, 1 (2-D) or other channel counts and break the 3-channel normalization.
        image = image.convert("RGB")
    image = np.array(image)
    image_resized = tf.image.resize(image, (224, 224))
    image_resized = tf.cast(image_resized, tf.float32)

    image_resized = tf.expand_dims(image_resized, 0)
    return norm_layer(image_resized)


def load_image_from_url(url, timeout=30):
    """Download an image and return it preprocessed by ``preprocess_image``.

    Args:
        url: Image URL.
        timeout: Seconds to wait for the server before giving up.

    Raises:
        requests.HTTPError: If the server responds with an error status, which
            would otherwise surface as a confusing PIL decoding error.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    image = Image.open(BytesIO(response.content))
    return preprocess_image(image)

# !wget https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt -O ilsvrc2012_wordnet_lemmas.txt

In [20]:
# ImageNet-1k class names, one per line, indexed by class id. The file is
# fetched on first use so that "Restart & Run All" works on a fresh checkout.
LABELS_PATH = "ilsvrc2012_wordnet_lemmas.txt"
LABELS_URL = "https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt"


def _read_labels(path):
    # One label per line, trailing whitespace/newline stripped.
    with open(path, "r") as f:
        return [line.rstrip() for line in f]


try:
    imagenet_int_to_str = _read_labels(LABELS_PATH)
except FileNotFoundError:
    labels_response = requests.get(LABELS_URL, timeout=30)
    labels_response.raise_for_status()
    with open(LABELS_PATH, "w") as f:
        f.write(labels_response.text)
    imagenet_int_to_str = _read_labels(LABELS_PATH)

img_url = "https://p0.pikrepo.com/preview/853/907/close-up-photo-of-gray-elephant.jpg"
image = load_image_from_url(img_url)

In [21]:
# End-to-end check: the ported model's top-1 class on the sample image.
expected_label = "Indian_elephant, Elephas_maximus"

predictions = deit_tiny_distilled_patch16_224_tf.predict(image)
logits = predictions[0]
# np.argmax flattens its input; that is fine for this single-image batch.
top1_class_id = int(np.argmax(logits))
predicted_label = imagenet_int_to_str[top1_class_id]

assert predicted_label == expected_label, (
    f"Expected {expected_label} but was {predicted_label}"
)

In [22]:
# Export the ported model to a directory next to the notebook.
deit_tiny_distilled_patch16_224_tf.save(filepath="deit_tiny_distilled_patch16_224_tf")

2022-03-27 09:14:28.456820: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.


INFO:tensorflow:Assets written to: deit_tiny_distilled_patch16_224_tf/assets


INFO:tensorflow:Assets written to: deit_tiny_distilled_patch16_224_tf/assets
