<a href="https://colab.research.google.com/github/sayakpaul/BiT-jax2tf/blob/main/convert_jax_weights_tf.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This notebook shows how to instantiate [BiT-ResNet models](https://arxiv.org/abs/1912.11370) in TensorFlow using code from the official repository [google-research/big_transfer](https://github.com/google-research/big_transfer) and load the original JAX weights into them. 

_**Note**: This notebook is authored by [Willi Gierke](https://ch.linkedin.com/in/willi-gierke) from Google. An initial version of the notebook was developed by Sayak Paul._

In [None]:
# For demonstration purposes, we will be operating with a BiT-ResNet152x2 model.
!wget https://storage.googleapis.com/bit_models/distill/R152x2_T_384.npz

!git clone --depth 1 https://github.com/google-research/big_transfer

import sys

sys.path.append("big_transfer")

from bit_tf2 import models
import tensorflow as tf
import numpy as np

from PIL import Image
from io import BytesIO
import requests


def preprocess_image(image):
    """Converts an image into a normalized batch holding one 384x384 image.

    The input is turned into a NumPy array, resized to 384x384 with
    `tf.image.resize`, cast to float32 and mapped through
    `(x - 127.5) / 127.5` (so 0 becomes -1 and 255 becomes 1). A leading
    batch axis of size 1 is added last.

    Args:
        image: Anything `np.array` accepts, e.g. a PIL image.

    Returns:
        A NumPy array whose first axis is the batch axis of size 1.
    """
    target_size = (384, 384)
    half_range = 127.5

    pixels = np.array(image)
    resized = tf.cast(tf.image.resize(pixels, target_size), tf.float32)
    normalized = (resized - half_range) / half_range
    batched = tf.expand_dims(normalized, 0)
    return batched.numpy()


def load_image_from_url(url):
    """Downloads an image and returns it preprocessed for the model.

    Args:
        url: HTTP(S) address of the image file.

    Returns:
        The result of `preprocess_image` for the downloaded image, i.e. an
        array with shape [1, height, width, num_channels].

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.Timeout: If the server does not answer within the timeout.
    """
    # Without a timeout `requests` may wait indefinitely on an unresponsive host.
    response = requests.get(url, timeout=30)
    # Fail loudly on 4xx/5xx instead of handing an error page to PIL.
    response.raise_for_status()
    image = Image.open(BytesIO(response.content))
    # Grayscale, palette or RGBA files would otherwise yield 1- or 4-channel
    # arrays; the model in this notebook is built with 3 input channels.
    image = image.convert("RGB")
    return preprocess_image(image)


def assert_valid_variables(model):
    """Raises an error if a weight contains only 0s or only 1s.

    Such a weight is treated as a sign that it was never overwritten with
    checkpoint values. For every entry of `model.layers` that exposes a
    `.layers` attribute, the weights of each of its sublayers are checked.
    Entries without `.layers` have their own weights checked, so they are no
    longer skipped silently. Progress (layer names and weight shapes) is
    printed along the way.

    Args:
        model: Object with a `.layers` sequence whose items provide `.name`
            and `get_weights()` (returning NumPy arrays), and optionally a
            nested `.layers` sequence of the same kind.

    Raises:
        RuntimeError: On the first non-empty weight that is entirely 0 or
            entirely 1.
    """

    def _check_weights(weights, owner):
        """Raises if any non-empty array in `weights` is all zeros or all ones."""
        for w in weights:
            print(w.shape)
            # `.all()` is vacuously True for zero-size arrays, so skip those.
            if w.size > 0 and ((w == 1.0).all() or (w == 0.0).all()):
                raise RuntimeError(f"PROBLEM in {owner}: {w}")

    for i, layer in enumerate(model.layers):
        print(f"Layer {i}: {layer.name}")
        if not hasattr(layer, "layers"):
            print(f"{layer.name} has no .layers")
            # A layer without sublayers may still own weights; check them too.
            _check_weights(layer.get_weights(), layer.name)
            continue
        for j, sublayer in enumerate(layer.layers):
            print(f"Sublayer {j}: {sublayer.name}")
            _check_weights(sublayer.get_weights(), f"{layer.name}.{sublayer.name}")

--2021-08-25 04:33:59--  https://storage.googleapis.com/bit_models/distill/R152x2_T_384.npz
Resolving storage.googleapis.com (storage.googleapis.com)... 64.233.170.128, 74.125.31.128, 173.194.210.128, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|64.233.170.128|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 945485848 (902M) [application/octet-stream]
Saving to: ‘R152x2_T_384.npz’


2021-08-25 04:34:20 (46.8 MB/s) - ‘R152x2_T_384.npz’ saved [945485848/945485848]

Cloning into 'big_transfer'...
remote: Enumerating objects: 31, done.[K
remote: Counting objects: 100% (31/31), done.[K
remote: Compressing objects: 100% (27/27), done.[K
remote: Total 31 (delta 1), reused 23 (delta 1), pack-reused 0[K
Unpacking objects: 100% (31/31), done.


In [None]:
# Load the labels.
!wget https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt

with open("ilsvrc2012_wordnet_lemmas.txt", "r") as f:
    lines = f.readlines()
imagenet_int_to_str = [line.rstrip() for line in lines]

# Load image (image provided is CC0 licensed)
img_url = "https://p0.pikrepo.com/preview/853/907/close-up-photo-of-gray-elephant.jpg"
image = load_image_from_url(img_url)

# Constructor arguments of the network instantiated below.
resnet_kwargs = dict(
    num_units=(3, 8, 36, 3),
    num_outputs=1000,
    filters_factor=8,
    name="resnet",
    trainable=True,
    dtype=tf.float32,
)
model = models.ResnetV2(**resnet_kwargs)

# Build for 384x384 inputs with 3 channels and an unspecified batch size.
input_shape = (None, 384, 384, 3)
model.build(input_shape)
model.summary()

# Print smaller numpy arrays.
np.set_printoptions(threshold=3, edgeitems=1)

--2021-08-25 04:34:29--  https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt
Resolving storage.googleapis.com (storage.googleapis.com)... 172.217.204.128, 172.217.203.128, 173.194.213.128, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|172.217.204.128|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 21675 (21K) [text/plain]
Saving to: ‘ilsvrc2012_wordnet_lemmas.txt’


2021-08-25 04:34:29 (112 MB/s) - ‘ilsvrc2012_wordnet_lemmas.txt’ saved [21675/21675]

Model: "resnet"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
root_block (Sequential)      (None, 96, 96, 128)       18816     
_________________________________________________________________
block1 (Sequential)          (None, 96, 96, 512)       855808    
_________________________________________________________________
block2 (Sequential)          (None, 48, 48, 1024)      9329664   
_

In [None]:
# Load the weights.
# Read every array of the checkpoint archive into a plain dict keyed by its name.
with np.load("R152x2_T_384.npz") as checkpoint:
    params_tf = {name: checkpoint[name] for name in checkpoint.files}

In [None]:
# Assign the weights of each block to the matching TF variables. Check params_tf for details.
# Maps the index of a block in `model.layers` to the number of units it holds.
units_by_block_nr = {1: 3, 2: 8, 3: 36, 4: 3}


def _assign_group_norm(group_norm, scope):
    """Copies `<scope>/group_norm/{beta,gamma}` from `params_tf` into `group_norm`."""
    # `assign` takes the NumPy arrays directly; wrapping each one in a temporary
    # `tf.Variable` would only allocate a throw-away copy of the weight.
    group_norm._beta.assign(params_tf[f"{scope}/group_norm/beta"])
    group_norm._gamma.assign(params_tf[f"{scope}/group_norm/gamma"])


for block_nr, units in units_by_block_nr.items():
    block = model.layers[block_nr]
    for unit_nr in range(units):
        unit = block.layers[unit_nr]
        # Checkpoint units are numbered from 1 with two digits ("unit01", ...).
        unit_scope = f"resnet/block{block_nr}/unit{unit_nr + 1:02d}"

        # Sub-unit "a": its convolution is a separate attribute of the unit
        # rather than the last layer of `_unit_a`.
        _assign_group_norm(unit._unit_a.layers[0], f"{unit_scope}/a")
        var_name = f"{unit_scope}/a/standardized_conv2d/kernel"
        if var_name in params_tf:
            unit._unit_a_conv.kernel.assign(params_tf[var_name])

        # Projection kernel: `_proj` is only dereferenced when the checkpoint
        # contains the entry (presumably not every unit has a projection --
        # verify against the model definition).
        var_name = f"{unit_scope}/a/proj/standardized_conv2d/kernel"
        if var_name in params_tf:
            unit._proj.kernel.assign(params_tf[var_name])

        # Sub-units "b" and "c" share one layout: group norm first, conv last.
        for part, sub_unit in (("b", unit._unit_b), ("c", unit._unit_c)):
            _assign_group_norm(sub_unit.layers[0], f"{unit_scope}/{part}")
            var_name = f"{unit_scope}/{part}/standardized_conv2d/kernel"
            # NOTE(review): a missing kernel is skipped silently, as before.
            if var_name in params_tf:
                sub_unit.layers[-1].kernel.assign(params_tf[var_name])

In [None]:
# Set the variables not included in the blocks.
# The NumPy arrays are handed to `assign` directly; a temporary `tf.Variable`
# around each of them would only create a throw-away copy.
root_conv = model.layers[0].layers[1]
root_conv.kernel.assign(params_tf["resnet/root_block/standardized_conv2d/kernel"])

final_group_norm = model.layers[5]
final_group_norm._gamma.assign(params_tf["resnet/group_norm/gamma"])
final_group_norm._beta.assign(params_tf["resnet/group_norm/beta"])

head = model.layers[-1]
# The checkpoint stores the head kernel under a "conv2d" name; it is reshaped
# to the 2-D (4096, 1000) layout before being assigned to the head layer.
head.kernel.assign(params_tf["resnet/head/conv2d/kernel"].reshape(4096, 1000))
head.bias.assign(params_tf["resnet/head/conv2d/bias"])

<tf.Variable 'UnreadVariable' shape=(1000,) dtype=float32, numpy=array([7.7493743e-05, ..., 8.2581966e-05], dtype=float32)>

In [None]:
# Verify that it works.
# The top-1 class for the elephant photo must be the expected ImageNet label.
logits = model.predict(image)
s = tf.nn.softmax(logits, 1)

top1_index = tf.argmax(s, -1).numpy()[0]
top1_label = imagenet_int_to_str[top1_index]
assert top1_label == "Indian_elephant, Elephas_maximus"