In [None]:
import sys

# Make the parent directory importable (presumably the repo root, so that
# `flax_models.vgg` in the imports cell resolves). The membership guard keeps
# the cell idempotent: re-running it does not keep appending duplicate
# "../" entries to sys.path.
if "../" not in sys.path:
    sys.path.append("../")

In [None]:
import jax
import jax.numpy as jnp

import os
import json
import wandb
from PIL import Image
import matplotlib.pyplot as plt

import tensorflow as tf
from flax_models.vgg import build_vgg16

In [None]:
# Start a W&B run tagged as an inference job; the artifact cell below calls
# wandb.use_artifact while this run is active.
wandb.init(
    entity="geekyrakshit",
    project="flax-vision-models",
    job_type="inference",
)

In [None]:
# Build the VGG16 model with pretrained weights. `model` is used with
# `model.apply(params, x)` in predict_jax below. show_parameter_overview=True
# presumably prints a parameter summary — confirm in flax_models.vgg.
model, params = build_vgg16(
    pretrained=True,
    show_parameter_overview=True,
)

In [None]:
# Download the class-name JSON logged as a W&B artifact. predict_jax indexes
# this list by predicted class id to print human-readable labels.
artifact = wandb.use_artifact(
    "geekyrakshit/flax-vision-models/imagenet-simple-labels:v0"
)
artifact_dir = artifact.download()
imagenet_labels_file = os.path.join(
    artifact_dir,
    "imagenet-simple-labels.json",
)

In [None]:
def predict_jax(
    model,
    params,
    prepocessing_fn,
    image_file,
    labels_file=None,
    top_k=5,
    target_size=(224, 224),
):
    """Show an image and print the model's top-k class predictions for it.

    Args:
        model: Object exposing ``apply(params, x)`` (presumably a Flax module),
            called on a batch containing one preprocessed image.
        params: Parameters passed as the first argument to ``model.apply``.
        prepocessing_fn: Callable applied to the float array returned by
            ``img_to_array`` before batching, e.g.
            ``tf.keras.applications.vgg16.preprocess_input``. The misspelled
            name is kept so existing keyword callers continue to work.
        image_file: Path or file object accepted by ``PIL.Image.open``.
        labels_file: JSON file holding a list of class names indexed by class
            id. Defaults to the notebook-level ``imagenet_labels_file``.
        top_k: Number of highest-scoring classes to print.
        target_size: ``(width, height)`` passed to ``PIL.Image.resize``.

    Returns:
        None. Prints one ``<label> <score>`` line per predicted class,
        highest score first.
    """
    if labels_file is None:
        # Backward-compatible fallback to the path set by the artifact cell.
        labels_file = imagenet_labels_file

    # The context manager releases the file handle PIL otherwise keeps open.
    # convert("RGB") guarantees three channels even for grayscale, palette or
    # RGBA inputs, and loads the pixels so they outlive the `with` block.
    with Image.open(image_file) as raw_image:
        image = raw_image.convert("RGB")

    plt.imshow(image)
    plt.show()

    image = image.resize(target_size)
    x = tf.keras.preprocessing.image.img_to_array(image)
    x = prepocessing_fn(x)
    x = jnp.expand_dims(x, axis=0)  # add a leading batch axis of size 1

    # NOTE(review): these are whatever model.apply returns; they are
    # probabilities only if the network ends in a softmax — confirm in
    # flax_models.vgg.
    scores = model.apply(params, x)
    top_scores, top_classes = jax.lax.top_k(scores, k=top_k)
    # Drop the batch axis so each array is 1-D with top_k entries.
    top_scores = jnp.squeeze(top_scores, axis=0)
    top_classes = jnp.squeeze(top_classes, axis=0)

    with open(labels_file, encoding="utf-8") as f:
        labels = json.load(f)

    # tolist() yields plain Python ints/floats for list indexing and printing.
    for class_id, score in zip(top_classes.tolist(), top_scores.tolist()):
        print(labels[class_id], score)

In [None]:
# Pair the VGG16 model built above with the preprocessing function from the
# same Keras application family (vgg16), rather than the vgg19 one.
predict_jax(
    model,
    params,
    tf.keras.applications.vgg16.preprocess_input,
    "../dog.jpg",
)

In [None]:
# Close the W&B run opened by the wandb.init cell at the top of the notebook.
wandb.finish()