In [None]:
import jax
import jax.numpy as jnp

import json
from PIL import Image
import matplotlib.pyplot as plt

import tensorflow as tf
from flax_models.vgg import build_vgg16

In [None]:
!wget https://cdn.pixabay.com/photo/2013/05/29/22/25/elephant-114543_960_720.jpg
!wget https://github.com/pytorch/hub/raw/master/images/dog.jpg
!wget https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json

In [None]:
# Reference Keras VGG16 with its ImageNet classifier head. Its weights are the
# source that gets copied into the Flax model further down.
model = tf.keras.applications.VGG16(
    # Pretrained ImageNet weights with the full 1000-way softmax classifier.
    weights="imagenet",
    include_top=True,
    classes=1000,
    classifier_activation="softmax",
    # No custom input tensor/shape and no global pooling of the conv features.
    input_tensor=None,
    input_shape=None,
    pooling=None,
)

In [None]:
# `build_vgg16` (project helper in flax_models.vgg) returns a pair: the Flax
# module used for `.apply` below, and a parameter tree whose "params" layout is
# inspected further down. NOTE(review): show_parameter_overview=True presumably
# prints a parameter summary; confirm in flax_models.vgg.
(jax_model, jax_params) = build_vgg16(
    show_parameter_overview=True,
)

In [None]:
# Name and shape of every Keras weight, in `model.weights` order. That order is
# what the Keras -> Flax mapping relies on, and the shapes make it easy to check
# which tensor is a kernel and which is a bias. Shown via rich display, not print().
[(weight.name, tuple(weight.shape)) for weight in model.weights]

In [None]:
# Peek at the Flax parameter layout: top-level modules, the first VGG block,
# and the first conv layer inside it.
(
    jax_params["params"].keys(),
    jax_params["params"]["VGGBlock_0"].keys(),
    jax_params["params"]["VGGBlock_0"]["Conv_0"].keys(),
)

In [None]:
# Keras lists VGG16's weights in `model.weights` as consecutive (kernel, bias)
# pairs: the 13 conv layers (grouped into 5 blocks) first, then the 3 dense
# layers. Generating the Flax tree from that layout replaces 32 hand-numbered
# indices, where a single typo would load the wrong tensor.
VGG16_CONVS_PER_BLOCK = (2, 2, 3, 3, 3)
VGG16_NUM_DENSE = 3


def keras_vgg16_to_flax_params(keras_weights):
    """Convert Keras VGG16 weights into the nested dict passed to `jax_model.apply`.

    Args:
        keras_weights: Keras weight variables in `model.weights` order, i.e.
            alternating kernel/bias tensors with conv layers before dense layers.

    Returns:
        {"params": {"VGGBlock_i": {"Conv_j": {"kernel", "bias"}}, "Dense_k": {...}}}
        with NumPy arrays as leaves.

    Raises:
        ValueError: if the number of tensors does not match the VGG16 layout.
    """
    arrays = [weight.numpy() for weight in keras_weights]
    expected = 2 * (sum(VGG16_CONVS_PER_BLOCK) + VGG16_NUM_DENSE)
    if len(arrays) != expected:
        raise ValueError(f"Expected {expected} weight tensors, got {len(arrays)}")

    # Pair up (kernel, bias) tensors and hand them out in order.
    layers = iter(zip(arrays[0::2], arrays[1::2]))

    params = {}
    for block_idx, num_convs in enumerate(VGG16_CONVS_PER_BLOCK):
        block = {}
        for conv_idx in range(num_convs):
            kernel, bias = next(layers)
            block[f"Conv_{conv_idx}"] = {"kernel": kernel, "bias": bias}
        params[f"VGGBlock_{block_idx}"] = block
    for dense_idx in range(VGG16_NUM_DENSE):
        kernel, bias = next(layers)
        params[f"Dense_{dense_idx}"] = {"kernel": kernel, "bias": bias}
    return {"params": params}


def _leaf_shapes(tree, path=()):
    """Flatten a nested mapping of arrays into {key_path_tuple: shape_tuple}."""
    if hasattr(tree, "shape"):
        return {path: tuple(tree.shape)}
    shapes = {}
    for key, subtree in tree.items():
        shapes.update(_leaf_shapes(subtree, path + (key,)))
    return shapes


pre_trained_params = keras_vgg16_to_flax_params(model.weights)

# Catch layout mismatches here, with a readable message, instead of deep inside
# `apply`: every converted tensor must exist in, and have the same shape as,
# the parameter tree the Flax model was initialized with.
_expected_shapes = _leaf_shapes(jax_params["params"])
_loaded_shapes = _leaf_shapes(pre_trained_params["params"])
assert _loaded_shapes == _expected_shapes, (
    "Converted Keras weights do not match the Flax parameter tree; differing "
    f"entries: {set(_expected_shapes.items()) ^ set(_loaded_shapes.items())}"
)

In [None]:
image = Image.open('elephant-114543_960_720.jpg')

plt.imshow(image)
plt.show()

image = image.resize((224, 224))
x = tf.keras.preprocessing.image.img_to_array(image)
x = tf.keras.applications.vgg16.preprocess_input(x)
x = jnp.expand_dims(x, axis=0)

key = jax.random.PRNGKey(0)

out = jax_model.apply(pre_trained_params, x)
top5_probs, top5_classes = jax.lax.top_k(out, k=5)
top5_probs = jnp.squeeze(top5_probs, axis=0)
top5_classes = jnp.squeeze(top5_classes, axis=0)

labels = json.load(open('imagenet-simple-labels.json'))
for i in range(top5_classes.shape[0]):
    print(labels[top5_classes[i]], top5_probs[i])

In [None]:
image = Image.open('dog.jpg')

plt.imshow(image)
plt.show()

image = image.resize((224, 224))
x = tf.keras.preprocessing.image.img_to_array(image)
x = tf.keras.applications.vgg16.preprocess_input(x)
x = jnp.expand_dims(x, axis=0)

key = jax.random.PRNGKey(0)

out = jax_model.apply(pre_trained_params, x)
top5_probs, top5_classes = jax.lax.top_k(out, k=5)
top5_probs = jnp.squeeze(top5_probs, axis=0)
top5_classes = jnp.squeeze(top5_classes, axis=0)

labels = json.load(open('imagenet-simple-labels.json'))
for i in range(top5_classes.shape[0]):
    print(labels[top5_classes[i]], top5_probs[i])