In [1]:
import timm

import tensorflow as tf
import numpy as np

In [2]:
import sys

sys.path.append("..")

from swins import SwinTransformer
from swins.layers import *
from swins.blocks import *
from utils import helpers

In [3]:
# Keyword arguments forwarded to the `SwinTransformer` constructor below.
cfg = {
    "patch_size": 4,
    "window_size": 7,
    "embed_dim": 96,
    "depths": (2, 2, 6, 2),
    "num_heads": (3, 6, 12, 24),
}

In [4]:
# Instantiate the TF Swin model with the hyper-parameters from `cfg`.
swin_tiny_patch4_window7_224_tf = SwinTransformer(
    name="swin_tiny_patch4_window7_224", **cfg
)
# Run one forward pass in inference mode on a random tensor of shape
# (2, 224, 224, 3). Presumably this is what creates the model's variables so
# they can be counted and overwritten in the cells below -- TODO confirm.
random_tensor = tf.random.normal((2, 224, 224, 3))
outputs = swin_tiny_patch4_window7_224_tf(random_tensor, training=False)
print("Swin TF model created.")

2022-05-08 18:23:21.505079: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


Swin TF model created.


In [5]:
# Load the reference PyTorch model from timm together with its pretrained
# weights.
swin_tiny_patch4_window7_224_pt = timm.create_model(
    model_name="swin_tiny_patch4_window7_224", pretrained=True
)
print("Swin PT model created.")
print("Number of parameters:")
# Total element count over all PyTorch parameters, printed in millions.
num_params = sum(p.numel() for p in swin_tiny_patch4_window7_224_pt.parameters())
print(num_params / 1e6)

# Sanity check: the TF model must report exactly the same parameter count.
assert swin_tiny_patch4_window7_224_tf.count_params() == num_params

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


Swin PT model created.
Number of parameters:
28.288354


In [6]:
# Snapshot every entry of the PyTorch state dict as a NumPy array, keyed by
# its original name.
state_dict = swin_tiny_patch4_window7_224_pt.state_dict()
np_state_dict = {name: tensor.numpy() for name, tensor in state_dict.items()}

In [7]:
# Projection.
# Copy the patch-embedding weights into the TF projection block: the
# `patch_embed.proj` weight and bias go to its layer 0, the `patch_embed.norm`
# weight and bias to its layer 2.
# NOTE(review): the results are stored back into the list returned by
# `.layers`; presumably the copy only takes effect because
# `helpers.modify_tf_block` updates the layer it is given -- confirm against
# the helper.
swin_tiny_patch4_window7_224_tf.projection.layers[0] = helpers.modify_tf_block(
    swin_tiny_patch4_window7_224_tf.projection.layers[0],
    np_state_dict["patch_embed.proj.weight"],
    np_state_dict["patch_embed.proj.bias"],
)
swin_tiny_patch4_window7_224_tf.projection.layers[2] = helpers.modify_tf_block(
    swin_tiny_patch4_window7_224_tf.projection.layers[2],
    np_state_dict["patch_embed.norm.weight"],
    np_state_dict["patch_embed.norm.bias"],
)

In [8]:
# Layer norm layers.
# Copy the top-level `norm` weight and bias into the second-to-last layer of
# the TF model.
# NOTE(review): the index -2 assumes the final layer norm sits directly before
# the classification head in `model.layers` -- confirm against the model
# definition.
ln_idx = -2
swin_tiny_patch4_window7_224_tf.layers[ln_idx] = helpers.modify_tf_block(
    swin_tiny_patch4_window7_224_tf.layers[ln_idx],
    np_state_dict["norm.weight"],
    np_state_dict["norm.bias"],
)

# Head layers.
# The head is looked up by name, but the helper's result is stored at index -1
# of `model.layers`.
# NOTE(review): presumably both refer to the same layer and the helper updates
# it in place -- confirm.
head_layer = swin_tiny_patch4_window7_224_tf.get_layer("classification_head")
swin_tiny_patch4_window7_224_tf.layers[-1] = helpers.modify_tf_block(
    head_layer,
    np_state_dict["head.weight"],
    np_state_dict["head.bias"],
)

In [9]:
# Show the PyTorch state-dict names that contain "layers.0".
[key for key in np_state_dict if "layers.0" in key]

['layers.0.blocks.0.norm1.weight',
 'layers.0.blocks.0.norm1.bias',
 'layers.0.blocks.0.attn.relative_position_bias_table',
 'layers.0.blocks.0.attn.relative_position_index',
 'layers.0.blocks.0.attn.qkv.weight',
 'layers.0.blocks.0.attn.qkv.bias',
 'layers.0.blocks.0.attn.proj.weight',
 'layers.0.blocks.0.attn.proj.bias',
 'layers.0.blocks.0.norm2.weight',
 'layers.0.blocks.0.norm2.bias',
 'layers.0.blocks.0.mlp.fc1.weight',
 'layers.0.blocks.0.mlp.fc1.bias',
 'layers.0.blocks.0.mlp.fc2.weight',
 'layers.0.blocks.0.mlp.fc2.bias',
 'layers.0.blocks.1.attn_mask',
 'layers.0.blocks.1.norm1.weight',
 'layers.0.blocks.1.norm1.bias',
 'layers.0.blocks.1.attn.relative_position_bias_table',
 'layers.0.blocks.1.attn.relative_position_index',
 'layers.0.blocks.1.attn.qkv.weight',
 'layers.0.blocks.1.attn.qkv.bias',
 'layers.0.blocks.1.attn.proj.weight',
 'layers.0.blocks.1.attn.proj.bias',
 'layers.0.blocks.1.norm2.weight',
 'layers.0.blocks.1.norm2.bias',
 'layers.0.blocks.1.mlp.fc1.weight',
 

In [10]:
def modify_swin_blocks(pt_weights_prefix, tf_block, state_dict=None):
    """Copy the PyTorch weights of one Swin stage into its TF layers.

    Args:
        pt_weights_prefix: Key prefix of the stage in the PyTorch state dict,
            e.g. "layers.0".
        tf_block: Iterable with the TF layers of the stage. `PatchMerging`
            and `SwinTransformerBlock` entries are updated; every other entry
            is left untouched.
        state_dict: Optional mapping from PyTorch parameter names to NumPy
            arrays. Defaults to the notebook-level `np_state_dict`.

    Returns:
        `tf_block`, the very object that was passed in.

    Raises:
        KeyError: If a parameter name expected for this stage is missing from
            the state dict.
    """
    weights = np_state_dict if state_dict is None else state_dict
    block_prefix = f"{pt_weights_prefix}.blocks"
    # Counts only `SwinTransformerBlock` entries, mirroring the numbering of
    # the `<prefix>.blocks.<n>` keys.
    block_idx = 0

    for outer_layer in tf_block:
        # Patch merging.
        if isinstance(outer_layer, PatchMerging):
            merge_prefix = f"{pt_weights_prefix}.downsample"

            outer_layer.reduction = helpers.modify_tf_block(
                outer_layer.reduction,
                weights[f"{merge_prefix}.reduction.weight"],
            )
            outer_layer.norm = helpers.modify_tf_block(
                outer_layer.norm,
                weights[f"{merge_prefix}.norm.weight"],
                weights[f"{merge_prefix}.norm.bias"],
            )

        # Swin layers.
        if isinstance(outer_layer, SwinTransformerBlock):
            # The PyTorch keys number these from 1 within each block
            # (norm1/norm2, fc1/fc2), so the counters restart per block.
            layernorm_idx = 1
            mlp_layer_idx = 1

            for inner_layer in outer_layer.layers:

                # Layer norm.
                if isinstance(inner_layer, tf.keras.layers.LayerNormalization):
                    norm_prefix = f"{block_prefix}.{block_idx}.norm{layernorm_idx}"
                    # `assign` takes the array directly; no temporary
                    # `tf.Variable` is needed.
                    inner_layer.gamma.assign(weights[f"{norm_prefix}.weight"])
                    inner_layer.beta.assign(weights[f"{norm_prefix}.bias"])
                    layernorm_idx += 1

                # Window attention.
                elif isinstance(inner_layer, WindowAttention):
                    attn_prefix = f"{block_prefix}.{block_idx}.attn"

                    # Relative position.
                    inner_layer.relative_position_bias_table = helpers.modify_tf_block(
                        inner_layer.relative_position_bias_table,
                        weights[f"{attn_prefix}.relative_position_bias_table"],
                    )
                    inner_layer.relative_position_index = helpers.modify_tf_block(
                        inner_layer.relative_position_index,
                        weights[f"{attn_prefix}.relative_position_index"],
                    )

                    # QKV.
                    inner_layer.qkv = helpers.modify_tf_block(
                        inner_layer.qkv,
                        weights[f"{attn_prefix}.qkv.weight"],
                        weights[f"{attn_prefix}.qkv.bias"],
                    )

                    # Projection.
                    inner_layer.proj = helpers.modify_tf_block(
                        inner_layer.proj,
                        weights[f"{attn_prefix}.proj.weight"],
                        weights[f"{attn_prefix}.proj.bias"],
                    )

                # MLP.
                elif isinstance(inner_layer, tf.keras.Model):
                    mlp_prefix = f"{block_prefix}.{block_idx}.mlp"
                    for mlp_layer in inner_layer.layers:
                        if isinstance(mlp_layer, tf.keras.layers.Dense):
                            # The return value is deliberately not bound:
                            # rebinding the loop variable would not change
                            # the model.
                            helpers.modify_tf_block(
                                mlp_layer,
                                weights[f"{mlp_prefix}.fc{mlp_layer_idx}.weight"],
                                weights[f"{mlp_prefix}.fc{mlp_layer_idx}.bias"],
                            )
                            mlp_layer_idx += 1

            block_idx += 1

    return tf_block

In [11]:
# Port the weights stored under the "layers.0" prefix into the sub-layers of
# the TF model's layer at index 2. The return value is not needed.
_ = modify_swin_blocks(
    "layers.0",
    swin_tiny_patch4_window7_224_tf.layers[2].layers,
)

In [12]:
# Verify that the TF layers of the stage ported above hold the same values as
# the PyTorch tensors under the "layers.0" prefix. The traversal mirrors
# `modify_swin_blocks`; any mismatch raises from `np.testing.assert_allclose`.
tf_block = swin_tiny_patch4_window7_224_tf.layers[2].layers
pt_weights_prefix = "layers.0"

# Patch merging.
for layer in tf_block:
    if isinstance(layer, PatchMerging):
        patch_merging_idx = f"{pt_weights_prefix}.downsample"
        # The PyTorch weight matrix is compared against the TF kernel after
        # being transposed.
        np.testing.assert_allclose(
            np_state_dict[f"{patch_merging_idx}.reduction.weight"].transpose(),
            layer.reduction.kernel.numpy(),
        )
        np.testing.assert_allclose(
            np_state_dict[f"{patch_merging_idx}.norm.weight"], layer.norm.gamma.numpy()
        )
        np.testing.assert_allclose(
            np_state_dict[f"{patch_merging_idx}.norm.bias"], layer.norm.beta.numpy()
        )

# Swin layers.
common_prefix = f"{pt_weights_prefix}.blocks"
# Advanced only for `SwinTransformerBlock` entries, to follow the numbering of
# the `<prefix>.blocks.<n>` keys.
block_idx = 0

for outer_layer in tf_block:

    # The PyTorch keys count norm1/norm2 and fc1/fc2 from 1 inside each block.
    layernorm_idx = 1
    mlp_layer_idx = 1

    if isinstance(outer_layer, SwinTransformerBlock):
        for inner_layer in outer_layer.layers:

            # Layer norm.
            if isinstance(inner_layer, tf.keras.layers.LayerNormalization):
                layer_norm_prefix = f"{common_prefix}.{block_idx}.norm{layernorm_idx}"
                np.testing.assert_allclose(
                    np_state_dict[f"{layer_norm_prefix}.weight"],
                    inner_layer.gamma.numpy(),
                )
                np.testing.assert_allclose(
                    np_state_dict[f"{layer_norm_prefix}.bias"], inner_layer.beta.numpy()
                )
                layernorm_idx += 1

            # Window attention.
            elif isinstance(inner_layer, WindowAttention):
                attn_prefix = f"{common_prefix}.{block_idx}.attn"

                # Relative position.
                np.testing.assert_allclose(
                    np_state_dict[f"{attn_prefix}.relative_position_bias_table"],
                    inner_layer.relative_position_bias_table.numpy(),
                )

                np.testing.assert_allclose(
                    np_state_dict[f"{attn_prefix}.relative_position_index"],
                    inner_layer.relative_position_index.numpy(),
                )

                # QKV.
                np.testing.assert_allclose(
                    np_state_dict[f"{attn_prefix}.qkv.weight"].transpose(),
                    inner_layer.qkv.kernel.numpy(),
                )
                np.testing.assert_allclose(
                    np_state_dict[f"{attn_prefix}.qkv.bias"],
                    inner_layer.qkv.bias.numpy(),
                )

                # Projection.
                np.testing.assert_allclose(
                    np_state_dict[f"{attn_prefix}.proj.weight"].transpose(),
                    inner_layer.proj.kernel.numpy(),
                )
                np.testing.assert_allclose(
                    np_state_dict[f"{attn_prefix}.proj.bias"],
                    inner_layer.proj.bias.numpy(),
                )

            # MLP.
            elif isinstance(inner_layer, tf.keras.Model):
                mlp_prefix = f"{common_prefix}.{block_idx}.mlp"
                for mlp_layer in inner_layer.layers:
                    if isinstance(mlp_layer, tf.keras.layers.Dense):
                        np.testing.assert_allclose(
                            np_state_dict[
                                f"{mlp_prefix}.fc{mlp_layer_idx}.weight"
                            ].transpose(),
                            mlp_layer.kernel.numpy(),
                        )
                        np.testing.assert_allclose(
                            np_state_dict[f"{mlp_prefix}.fc{mlp_layer_idx}.bias"],
                            mlp_layer.bias.numpy(),
                        )

                        mlp_layer_idx += 1

        block_idx += 1

In [13]:
# Port every stage listed in `cfg["depths"]`: stage `i` reads the PyTorch keys
# under `layers.{i}` and writes to the sub-layers of TF model layer `i + 2`.
num_stages = len(cfg["depths"])
for i in range(num_stages):
    stage_layers = swin_tiny_patch4_window7_224_tf.layers[i + 2].layers
    _ = modify_swin_blocks(f"layers.{i}", stage_layers)

In [14]:
import requests
from PIL import Image
from io import BytesIO

import matplotlib.pyplot as plt

In [15]:
input_resolution = 224

# Per-channel statistics, scaled from the unit range to the 0-255 pixel range.
channel_means = [m * 255 for m in (0.485, 0.456, 0.406)]
channel_variances = [(s * 255) ** 2 for s in (0.229, 0.224, 0.225)]

# Square centre crop at the model's input resolution, followed by per-channel
# normalisation with the statistics above.
crop_layer = tf.keras.layers.CenterCrop(input_resolution, input_resolution)
norm_layer = tf.keras.layers.Normalization(
    mean=channel_means, variance=channel_variances
)


def preprocess_image(image, size=input_resolution):
    """Resize, centre-crop and normalise one image for the model.

    Args:
        image: Anything `np.array` can convert to an image array, e.g. a PIL
            image.
        size: Side length of the square crop that is returned. The image is
            first resized to `int((256 / 224) * size)` on both sides.

    Returns:
        A NumPy array holding the normalised crop, with a leading batch axis
        of length 1.
    """
    image = np.array(image)
    batched = tf.expand_dims(image, 0)
    resize_size = int((256 / 224) * size)
    resized = tf.image.resize(
        batched, (resize_size, resize_size), method="bicubic"
    )
    # The notebook-level `crop_layer` is fixed to `input_resolution`. Using it
    # for any other `size` would silently return the wrong crop, so a matching
    # crop layer is built in that case.
    if size == input_resolution:
        cropper = crop_layer
    else:
        cropper = tf.keras.layers.CenterCrop(size, size)
    return norm_layer(cropper(resized)).numpy()


def load_image_from_url(url, timeout=10):
    """Download an image and return it together with its preprocessed form.

    Args:
        url: Address of the image to fetch.
        timeout: Seconds to wait for the server before giving up, so that an
            unresponsive host cannot hang the notebook.

    Returns:
        A `(image, preprocessed_image)` tuple: the decoded PIL image converted
        to RGB, and the result of `preprocess_image` for it.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.Timeout: If the server does not answer within `timeout`.
    """
    # Credit: Willi Gierke
    response = requests.get(url, timeout=timeout)
    # Fail here with the HTTP error instead of later, when PIL would try to
    # decode an error page.
    response.raise_for_status()
    # Force three channels: the normalisation layer carries three per-channel
    # statistics, so grayscale or RGBA inputs would not match it.
    image = Image.open(BytesIO(response.content)).convert("RGB")
    preprocessed_image = preprocess_image(image)
    return image, preprocessed_image

In [16]:
# !wget https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt -O ilsvrc2012_wordnet_lemmas.txt

In [17]:
# Read the label file: one label per line, and a line's position serves as the
# class index when predictions are decoded.
with open("ilsvrc2012_wordnet_lemmas.txt", "r") as f:
    lines = list(f)
imagenet_int_to_str = list(map(str.rstrip, lines))

# Fetch the test image and its preprocessed counterpart.
img_url = "https://p0.pikrepo.com/preview/853/907/close-up-photo-of-gray-elephant.jpg"
image, preprocessed_image = load_image_from_url(img_url)

In [18]:
# End-to-end check of the ported weights: classify the test image and compare
# the top-scoring label with the expected one.
predictions = swin_tiny_patch4_window7_224_tf.predict(preprocessed_image)
# The batch contains a single image; take its scores.
logits = predictions[0]
predicted_label = imagenet_int_to_str[int(np.argmax(logits))]
expected_label = "Indian_elephant, Elephas_maximus"
assert (
    predicted_label == expected_label
), f"Expected {expected_label} but was {predicted_label}"

In [19]:
# Ask the model for its attention scores on the preprocessed image and show
# the top-level keys of the returned mapping.
all_attn_scores = swin_tiny_patch4_window7_224_tf.get_attention_scores(
    preprocessed_image
)
all_attn_scores.keys()

dict_keys(['swin_stage_0', 'swin_stage_1', 'swin_stage_2', 'swin_stage_3'])

In [20]:
# Show the keys stored under the "swin_stage_3" entry.
all_attn_scores["swin_stage_3"].keys()

dict_keys(['swin_block_0', 'swin_block_1'])

In [21]:
# Shape of the scores of the first block in that stage; presumably laid out
# as (batch, heads, window tokens, window tokens) -- TODO confirm.
all_attn_scores["swin_stage_3"]["swin_block_0"].shape

TensorShape([1, 24, 49, 49])

In [22]:
# Export the ported model to the hard-coded `gs://` destination.
# NOTE(review): presumably this needs write access to that bucket -- confirm
# before re-running.
swin_tiny_patch4_window7_224_tf.save("gs://swin-tf/swin_tiny_patch4_window7_224_tf")

2022-05-08 18:23:42.809960: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.


INFO:tensorflow:Assets written to: gs://swin-tf/swin_tiny_patch4_window7_224_tf/assets


INFO:tensorflow:Assets written to: gs://swin-tf/swin_tiny_patch4_window7_224_tf/assets
