<a href="https://colab.research.google.com/github/sayakpaul/ConvNeXt-TF/blob/main/notebooks/classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Setup

In [None]:
# Download the ImageNet-1k label list; a later cell opens it by its file name,
# "ilsvrc2012_wordnet_lemmas.txt".
!wget https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt

In [None]:
import tensorflow as tf
import tensorflow_hub as hub
from tensorflow import keras


from PIL import Image
from io import BytesIO

import matplotlib.pyplot as plt
import numpy as np
import requests

## Image preprocessing utilities 

In [None]:
# Center crop used by the 224-pixel preprocessing path.
crop_layer = keras.layers.CenterCrop(height=224, width=224)

# Per-channel normalization. The statistics are written on a 0-1 scale and
# multiplied by 255 here, so the layer is meant for pixel values in [0, 255].
# The layer takes a variance, hence the squared (scaled) standard deviations.
norm_layer = keras.layers.Normalization(
    mean=[channel_mean * 255 for channel_mean in (0.485, 0.456, 0.406)],
    variance=[(channel_std * 255) ** 2 for channel_std in (0.229, 0.224, 0.225)],
)


def preprocess_image(image, size=224):
    """Convert one image into a normalized, batched NumPy array.

    Args:
        image: Anything ``np.array`` can convert (e.g. a PIL image).
            Assumed to be channels-last with three channels and pixel values
            in [0, 255], because ``norm_layer`` carries three per-channel
            statistics scaled by 255 -- TODO confirm for other callers.
        size: Target resolution. ``224`` resizes to 256x256 and then applies
            ``crop_layer``; ``384`` resizes straight to 384x384. Any other
            value skips resizing and only normalizes the image.

    Returns:
        The normalized image as a NumPy array with a leading batch axis of
        length 1.
    """
    image = np.array(image)
    # Add the batch axis first so that every branch below operates on (and
    # returns) a batched tensor.
    image_resized = tf.expand_dims(image, 0)

    if size == 224:
        image_resized = tf.image.resize(image_resized, (256, 256), method="bicubic")
        image_resized = crop_layer(image_resized)
    elif size == 384:
        # Resize the batched tensor, not the raw array: resizing `image`
        # here would drop the batch axis and return a result of a different
        # rank than the 224 path.
        image_resized = tf.image.resize(image_resized, (size, size), method="bicubic")

    return norm_layer(image_resized).numpy()
    

def load_image_from_url(url, size=224):
    """Download an image and return it together with its preprocessed form.

    Args:
        url: HTTP(S) address of the image to fetch.
        size: Target resolution forwarded to ``preprocess_image``.

    Returns:
        A ``(image, preprocessed_image)`` tuple: the PIL image as decoded
        from the response, and the output of ``preprocess_image`` for it.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.Timeout: If the server does not respond within the timeout.
    """
    # Credit: Willi Gierke
    # A timeout keeps the cell from hanging forever on an unresponsive host.
    response = requests.get(url, timeout=30)
    # Fail on 4xx/5xx here instead of handing an error page to PIL.
    response.raise_for_status()
    image = Image.open(BytesIO(response.content))
    # `norm_layer` holds three-channel statistics, so preprocessing gets an
    # RGB copy (RGBA, palette or grayscale input would otherwise not match).
    # The image returned to the caller is left exactly as decoded.
    preprocessed_image = preprocess_image(image.convert("RGB"), size=size)
    return image, preprocessed_image

## Load ImageNet-1k labels and a demo image

In [None]:
# The position of a line in the label file is used as the class index, so
# `imagenet_int_to_str[i]` is the label text for index `i`.
with open("ilsvrc2012_wordnet_lemmas.txt", "r") as f:
    lines = list(f)
imagenet_int_to_str = [raw_label.rstrip() for raw_label in lines]

# Fetch the demo picture once: `image` is for display, `preprocessed_image`
# is what gets fed to the model.
img_url = "https://p0.pikrepo.com/preview/853/907/close-up-photo-of-gray-elephant.jpg"
image, preprocessed_image = load_image_from_url(img_url)

# Show the image as downloaded.
plt.imshow(image)
plt.show()

## Run inference

In [None]:
model_url = "https://tfhub.dev/sayakpaul/convnext_tiny_1k_224/1"

# Wrap the TF-Hub module in a one-layer Keras model so `.predict` is available.
classification_model = tf.keras.Sequential([hub.KerasLayer(model_url)])

# Score the demo image, then look up the highest-scoring class index in the
# label list.
predictions = classification_model.predict(preprocessed_image)
top_class_index = int(np.argmax(predictions))
predicted_label = imagenet_int_to_str[top_class_index]
print(predicted_label)