## Imports

In [1]:
import json
import os
import re

import tensorflow as tf
import tensorflow_hub as hub
from imutils import paths
from tensorflow import keras
from tensorflow.keras import layers

## Constants

In [2]:
# Evaluation settings.
BATCH_SIZE = 256
IMAGE_SIZE = 224  # square input resolution of the checkpoints being evaluated

# Runtime / storage settings.
AUTO = tf.data.AUTOTUNE  # let tf.data choose parallelism and prefetch depth
TF_MODEL_ROOT = "gs://convnext/saved_models"  # one subdirectory per exported checkpoint

## Set up ImageNet-1k labels

In [3]:
with open("imagenet_class_index.json", "r") as label_file:
    imagenet_labels = json.load(label_file)

# Each entry maps a stringified class index to a pair whose first element is the
# WordNet id (the folder name used under val/) and whose second is the class name.
# MAPPING_DICT: WordNet id -> integer class index (used to label validation images).
MAPPING_DICT = {entry[0]: int(idx) for idx, entry in imagenet_labels.items()}
# LABEL_NAMES: integer class index -> readable class name.
LABEL_NAMES = {int(idx): entry[1] for idx, entry in imagenet_labels.items()}

In [4]:
# Image paths look like val/<wordnet_id>/<file>.JPEG, so the class folder is the
# image's parent directory. Reading it via os.path (rather than splitting on "/")
# also works with Windows separators and with a nested or absolute root directory.
all_val_paths = list(paths.list_images("val"))
assert all_val_paths, "No images found under 'val/' - check the dataset location."
all_val_labels = [
    MAPPING_DICT[os.path.basename(os.path.dirname(image_path))]
    for image_path in all_val_paths
]

all_val_paths[:5], all_val_labels[:5]

(['val/n01751748/ILSVRC2012_val_00031060.JPEG',
  'val/n01751748/ILSVRC2012_val_00013492.JPEG',
  'val/n01751748/ILSVRC2012_val_00033108.JPEG',
  'val/n01751748/ILSVRC2012_val_00021437.JPEG',
  'val/n01751748/ILSVRC2012_val_00025096.JPEG'],
 [65, 65, 65, 65, 65])

## Preprocessing utilities

In [5]:
def load_and_prepare(path, label, resize_size=256):
    """Read one validation image and resize it like ConvNeXt's evaluation transform.

    Mirrors ``transforms.Resize(resize_size)`` from the reference ``datasets.py``:
    the *shorter* side is scaled to ``resize_size`` with bicubic interpolation,
    preserving the aspect ratio instead of warping the image to a square. The
    longer side is then center-cropped to ``resize_size``. This gives every image
    a static ``(resize_size, resize_size, 3)`` shape, so it can still be batched.
    The final center crop to the model resolution happens later, in the
    preprocessing model.

    Args:
        path: Scalar string tensor holding the image file path.
        label: Class id; passed through unchanged.
        resize_size: Target length of the shorter side (and of the square
            output). 256 corresponds to crop_pct = 224 / 256 at 224px.

    Returns:
        ``(image, label)``. ``image`` is float32 with shape
        ``(resize_size, resize_size, 3)`` and values clipped to [0, 255].
    """
    raw = tf.io.read_file(path)
    # decode_image picks the right decoder from the file header (the val images
    # are JPEGs); expand_animations=False guarantees a rank-3 result.
    image = tf.io.decode_image(raw, channels=3, expand_animations=False)

    height = tf.shape(image)[0]
    width = tf.shape(image)[1]
    short_side = tf.minimum(height, width)
    long_side = tf.maximum(height, width)
    # Same rounding as torchvision's Resize(int): int(size * long / short).
    new_long = tf.cast(
        tf.math.floor(
            resize_size
            * tf.cast(long_side, tf.float64)
            / tf.cast(short_side, tf.float64)
        ),
        tf.int32,
    )
    height_is_short = height <= width
    new_height = tf.where(height_is_short, resize_size, new_long)
    new_width = tf.where(height_is_short, new_long, resize_size)

    image = tf.image.resize(image, tf.stack([new_height, new_width]), method="bicubic")
    # The shorter side already equals resize_size, so this only center-crops the
    # longer side. It also sets a static spatial shape for batching.
    image = tf.image.resize_with_crop_or_pad(image, resize_size, resize_size)
    # Bicubic interpolation can overshoot the valid pixel range.
    image = tf.clip_by_value(image, 0.0, 255.0)
    return image, label

In [6]:
# Reference: https://github.com/facebookresearch/ConvNeXt/blob/main/datasets.py
def get_preprocessing_model(input_size=224):
    """Build the evaluation-time preprocessing pipeline as a Keras model.

    The model center-crops each image to ``input_size`` x ``input_size``, then
    standardizes every channel with the ImageNet mean/std. The statistics are
    multiplied by 255 because the pipeline assumes pixel values in [0, 255].

    Args:
        input_size: Side length of the square center crop.

    Returns:
        A ``keras.Sequential`` model applying the crop followed by normalization.
    """
    imagenet_mean = (0.485, 0.456, 0.406)
    imagenet_std = (0.229, 0.224, 0.225)

    return keras.Sequential(
        [
            layers.CenterCrop(input_size, input_size),
            layers.Normalization(
                mean=[m * 255 for m in imagenet_mean],
                # Normalization expects variance, i.e. the squared std.
                variance=[(s * 255) ** 2 for s in imagenet_std],
            ),
        ]
    )

## Prepare `tf.data.Dataset`

In [7]:
# Crop to the configured evaluation resolution instead of relying on the
# function's hard-coded default.
preprocessor = get_preprocessing_model(IMAGE_SIZE)

# Batch before the Keras preprocessing model so the crop and normalization run
# once per batch rather than once per image.
dataset = (
    tf.data.Dataset.from_tensor_slices((all_val_paths, all_val_labels))
    .map(load_and_prepare, num_parallel_calls=AUTO)
    .batch(BATCH_SIZE)
    .map(lambda images, labels: (preprocessor(images), labels), num_parallel_calls=AUTO)
    .prefetch(AUTO)
)
dataset.element_spec

2022-01-31 03:20:05.306146: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-01-31 03:20:05.828828: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1525] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 38444 MB memory:  -> device: 0, name: A100-SXM4-40GB, pci bus id: 0000:00:04.0, compute capability: 8.0


(TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32, name=None),
 TensorSpec(shape=(None,), dtype=tf.int32, name=None))

## Fetch model paths and keep the 224x224 ImageNet-1k models

In [8]:
model_paths = tf.io.gfile.listdir(TF_MODEL_ROOT)

# Directory names look like "convnext_<size>_<pretraining>_<resolution>/".
# An exact suffix match on the resolution avoids accidental substring hits,
# and both filters now follow IMAGE_SIZE instead of a hard-coded 224.
# (The name models_res_224 is kept for compatibility with other cells.)
resolution_suffix = f"_{IMAGE_SIZE}"
models_res_224 = [
    model_path
    for model_path in model_paths
    if model_path.rstrip("/").endswith(resolution_suffix)
]

# Drop the ImageNet-21k-only checkpoints ("..._21k_<res>"), which presumably carry
# a 21k-class head. "..._21k_1k_<res>" (21k-pretrained, 1k-fine-tuned) are kept.
in21k_only_suffix = f"_21k_{IMAGE_SIZE}"
i1k_paths = [
    model_path
    for model_path in models_res_224
    if not model_path.rstrip("/").endswith(in21k_only_suffix)
]

print(i1k_paths)

['convnext_base_1k_224/', 'convnext_base_21k_1k_224/', 'convnext_large_1k_224/', 'convnext_large_21k_1k_224/', 'convnext_small_1k_224/', 'convnext_tiny_1k_224/', 'convnext_xlarge_21k_1k_224/']


## Run evaluation

In [13]:
def get_model(model_url, input_size=None):
    """Wrap an exported classifier in a Keras model with a fixed square RGB input.

    Args:
        model_url: Path or URL of the exported model (here a directory under
            ``TF_MODEL_ROOT``) loadable by ``hub.KerasLayer``.
        input_size: Side length of the expected input images. Defaults to the
            notebook-level ``IMAGE_SIZE``, so the input shape follows the
            configured evaluation resolution instead of a hard-coded 224.

    Returns:
        An uncompiled ``tf.keras.Sequential`` model.
    """
    # Resolved at call time (not as a def-time default) so this cell does not
    # depend on IMAGE_SIZE existing when the function is defined.
    if input_size is None:
        input_size = IMAGE_SIZE

    classification_model = tf.keras.Sequential(
        [
            layers.InputLayer((input_size, input_size, 3)),
            hub.KerasLayer(model_url),
        ]
    )
    return classification_model


def evaluate_model(model_name):
    """Evaluate one checkpoint on ``dataset`` and record its accuracy.

    Args:
        model_name: Directory name under ``TF_MODEL_ROOT`` as returned by
            ``tf.io.gfile.listdir`` (may carry a trailing ``/``).

    Returns:
        The accuracy as a percentage, rounded to 4 decimal places. The same value
        is written to ``<model_name>.txt``, and TensorBoard logs go to
        ``logs_<model_name>``.
    """
    tb_callback = tf.keras.callbacks.TensorBoard(log_dir=f"logs_{model_name}")
    model_url = TF_MODEL_ROOT + "/" + model_name

    model = get_model(model_url)
    model.compile(metrics=["accuracy"])
    # Only the accuracy metric is used; the first returned value is discarded.
    _, accuracy = model.evaluate(dataset, callbacks=[tb_callback])
    accuracy = round(accuracy * 100, 4)

    # A context manager flushes and closes the result file deterministically.
    # Previously the handle was left open until garbage collection.
    with open(f"{model_name.strip('/')}.txt", "w") as result_file:
        print(f"{model_name}: {accuracy}%.", file=result_file)
    return accuracy

In [14]:
# Evaluate every ImageNet-1k checkpoint in turn; evaluate_model records each
# accuracy in a per-checkpoint .txt file.
for i1k_path in i1k_paths:
    print("Evaluating " + i1k_path + ".")
    evaluate_model(i1k_path)

Evaluating convnext_base_1k_224/.
Evaluating convnext_base_21k_1k_224/.
Evaluating convnext_large_1k_224/.
Evaluating convnext_large_21k_1k_224/.
Evaluating convnext_small_1k_224/.
Evaluating convnext_tiny_1k_224/.
Evaluating convnext_xlarge_21k_1k_224/.
