<a href="https://colab.research.google.com/github/soumik12345/flax-vision-models/blob/main/notebooks/inference.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install Flax Vision Models
!pip install -q git+https://github.com/soumik12345/flax-vision-models
# Fetch image for inference
!wget https://github.com/pytorch/hub/raw/master/images/dog.jpg

In [None]:
import jax
import jax.numpy as jnp

import os
import wandb
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

from flax_vision_models.vgg import build_vgg16, preprocess as preprocess_fn
from flax_vision_models.utils import decode_probabilities_imagenet

In [None]:
# Build the VGG16 model and its parameters, requesting pretrained weights and a parameter overview.
model, params = build_vgg16(pretrained=True, show_parameter_overview=True)

In [None]:
def infer(x, model, params, k=5):
    """Run ``model`` on a single-example batch and return its top-``k`` outputs.

    Args:
        x: Model input with a leading batch dimension of size 1. The batch axis
            is squeezed away below, which raises for any other batch size.
        model: Object whose ``apply(params, x)`` returns per-class scores with
            classes on the last axis (presumably ImageNet class probabilities;
            not verified here).
        params: Parameters forwarded to ``model.apply``.
        k: Number of highest-scoring entries to keep (default 5).

    Returns:
        ``(classes, scores)``: 1-D arrays of length ``k`` holding the class
        indices and their scores, ordered from highest to lowest score. Note the
        order is swapped relative to ``jax.lax.top_k``, which returns
        ``(values, indices)``.
    """
    out = model.apply(params, x)
    # top_k operates along the last axis and returns (values, indices).
    topk_scores, topk_classes = jax.lax.top_k(out, k=k)
    # Drop the batch axis; jnp.squeeze raises if that axis is not of size 1.
    topk_scores = jnp.squeeze(topk_scores, axis=0)
    topk_classes = jnp.squeeze(topk_classes, axis=0)
    return topk_classes, topk_scores


def predict_jax(model, params, prepocessing_fn, image_file, image_size=(224, 224)):
    """Display an image, run the model on it and print the decoded predictions.

    Args:
        model: Model passed through to ``infer`` (its ``apply`` is called with
            ``params``).
        params: Parameters passed through to ``infer``.
        prepocessing_fn: Callable applied to the resized image as a NumPy array
            before a batch axis is added. (The misspelled name is kept so that
            keyword callers keep working.)
        image_file: Path or file object accepted by ``PIL.Image.open``.
        image_size: ``(width, height)`` the image is resized to before
            preprocessing; defaults to ``(224, 224)``.

    Prints one ``label value`` line per entry returned by
    ``decode_probabilities_imagenet``. Returns ``None``.
    """
    # The with-block closes the underlying file once the pixels are decoded.
    # Converting to RGB forces 3 channels, so grayscale / RGBA / palette images
    # don't produce arrays with the wrong channel count for the model.
    with Image.open(image_file) as raw_image:
        image = raw_image.convert("RGB")

    plt.imshow(image)
    plt.axis("off")
    plt.show()

    x = np.array(image.resize(image_size))
    x = prepocessing_fn(x)
    x = jnp.expand_dims(x, axis=0)  # add a batch dimension of size 1

    top5_classes, top5_probs = infer(x, model, params)
    topk_labels, topk_probabilities = decode_probabilities_imagenet(top5_classes, top5_probs)
    for label, probability in zip(topk_labels, topk_probabilities):
        print(label, probability)

In [None]:
# Run the VGG16 predictor on the downloaded sample image.
predict_jax(model, params, preprocess_fn, image_file='./dog.jpg')