## Imports

In [None]:
import tensorflow as tf
import tensorflow_hub as hub

from PIL import Image
from io import BytesIO
import matplotlib.pyplot as plt
import numpy as np
import requests

## Get the `SavedModel`

In [None]:
# TODO: Remove this when Hub models are published.
# Recursively (`-r`) copies the SavedModel directory from the GCS bucket into
# the current working directory; requires `gsutil` to be available on PATH.
!gsutil cp -r  gs://flowers-experimental/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224 .

## Image preprocessing utilities (credits: [Willi Gierke](https://ch.linkedin.com/in/willi-gierke))

In [None]:
def preprocess_image(image, size=(224, 224)):
    """Turn an image into a normalised single-item batch for the classifier.

    Args:
        image: A `PIL.Image.Image`, or anything `np.array` can turn into an
            image array that `tf.image.resize` accepts (3-D or 4-D).
        size: `(height, width)` to resize to. Defaults to 224x224, matching
            the `res_224` model this notebook downloads.

    Returns:
        A float32 NumPy array with a leading batch axis of length 1. Pixel
        values are mapped with `(x - 127.5) / 127.5`, i.e. to [-1, 1] when
        the input is in the usual 8-bit [0, 255] range.
    """
    # Grayscale, palette and RGBA files decode to 2-D or 4-channel arrays,
    # which `tf.image.resize` / the 3-channel model reject. Normalise PIL
    # inputs to RGB up front; array inputs are passed through unchanged.
    if isinstance(image, Image.Image):
        image = image.convert("RGB")
    image = np.array(image)
    image_resized = tf.image.resize(image, size)
    image_resized = tf.cast(image_resized, tf.float32)
    image_resized = (image_resized - 127.5) / 127.5
    # Add the batch axis the model's `predict` expects.
    return tf.expand_dims(image_resized, 0).numpy()

def load_image_from_url(url, timeout=30):
    """Download an image and return it preprocessed for the model.

    Args:
        url: HTTP(S) address of the image file.
        timeout: Seconds to wait on the server before giving up. `requests`
            applies no timeout by default, so without one a stalled server
            would hang the notebook indefinitely.

    Returns:
        The result of `preprocess_image` applied to the decoded image.

    Raises:
        requests.exceptions.HTTPError: The server replied with a 4xx/5xx
            status.
        requests.exceptions.Timeout: The server did not respond within
            `timeout` seconds.
    """
    response = requests.get(url, timeout=timeout)
    # Surface HTTP failures here, with the status code, rather than letting
    # PIL fail later trying to decode an HTML error page as an image.
    response.raise_for_status()
    image = Image.open(BytesIO(response.content))
    image = preprocess_image(image)
    return image

# Download the label file into the working directory; the inference cell opens
# it by the relative name "ilsvrc2012_wordnet_lemmas.txt".
!wget https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt

## Load image and infer

In [None]:
# Build the class-index -> label lookup: one label per line, in file order.
with open("ilsvrc2012_wordnet_lemmas.txt", "r") as label_file:
    lines = list(label_file)
imagenet_int_to_str = list(map(str.rstrip, lines))

# Fetch the sample picture and run it through the preprocessing pipeline.
img_url = "https://p0.pikrepo.com/preview/853/907/close-up-photo-of-gray-elephant.jpg"
image = load_image_from_url(img_url)

# `preprocess_image` scales pixels with (x - 127.5) / 127.5; map the single
# batch item back to [0, 1] so matplotlib can render the float array.
preview = (image[0] + 1) / 2
plt.imshow(preview)
plt.show()

In [None]:
# TODO: Update when models are published.
# Local directory holding the SavedModel (same name as the one copied with
# gsutil earlier in the notebook).
model_dir = "S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224"

# Wrap the SavedModel in a one-layer Keras model so `.predict` is available.
hub_layer = hub.KerasLayer(model_dir)
classification_model = tf.keras.Sequential([hub_layer])

predictions = classification_model.predict(image)

# `np.argmax` without an axis works on the flattened output, which is the
# class index here because the batch contains a single image.
top_index = int(np.argmax(predictions))
predicted_label = imagenet_int_to_str[top_index]
predicted_label