## Imports

In [1]:
import sys

sys.path.append("..")

import convnext

In [2]:
from tensorflow import keras
import tensorflow as tf

from imutils import paths
import json
import os

## Constants

In [3]:
# Evaluation batch size used when batching the validation tf.data pipeline.
BATCH_SIZE = 256
# Square image side length; presumably intended to match the 224x224 crop
# produced during preprocessing (not referenced directly elsewhere yet).
IMAGE_SIZE = 224
# tf.data autotuning sentinel, used for parallel map calls and prefetching.
AUTO = tf.data.AUTOTUNE
# Remote directory holding the `convnext_<variant>.h5` checkpoints.
BASE_WEIGHTS_PATH = (
    "https://storage.googleapis.com/convnext-tf/"
    "keras-applications-temp/convnext/"
)

In [4]:
# Per-variant keyword arguments forwarded verbatim to convnext.ConvNeXt(**config).
# "depths" / "projection_dims" presumably follow the ConvNeXt paper (blocks per
# stage / channels per stage) — their exact semantics are defined in convnext.py.
# Iteration order (tiny -> xlarge) is the order in which models are evaluated.
MODEL_CONFIGS = {
    variant: {
        "depths": stage_depths,
        "projection_dims": stage_widths,
        "default_size": 224,
    }
    for variant, stage_depths, stage_widths in (
        ("tiny", [3, 3, 9, 3], [96, 192, 384, 768]),
        ("small", [3, 3, 27, 3], [96, 192, 384, 768]),
        ("base", [3, 3, 27, 3], [128, 256, 512, 1024]),
        ("large", [3, 3, 27, 3], [192, 384, 768, 1536]),
        ("xlarge", [3, 3, 27, 3], [256, 512, 1024, 2048]),
    )
}

In [5]:
# Expected checkpoint digests per ConvNeXt variant, keyed like MODEL_CONFIGS.
# Each value is a pair of 64-hex-digit digests (Keras `get_file` treats
# 64-character hashes as SHA-256). Only element [0] is used in this notebook:
# it is passed as `file_hash` when downloading `convnext_<variant>.h5`.
# Element [1] is presumably the digest of the include_top=False ("notop")
# checkpoint — TODO confirm against the published files.
WEIGHTS_HASHES = {
  "tiny":
    ("8ae6e78ce2933352b1ef4008e6dd2f17bc40771563877d156bc6426c7cf503ff",
      "d547c096cabd03329d7be5562c5e14798aa39ed24b474157cef5e85ab9e49ef1"),
  "small":
    ("ce1277d8f1ee5a0ef0e171469089c18f5233860ceaf9b168049cb9263fd7483c",
      "6fc8009faa2f00c1c1dfce59feea9b0745eb260a7dd11bee65c8e20843da6eab"),
  "base":
    ("52cbb006d3dadd03f6e095a8ca1aca47aecdd75acb4bc74bce1f5c695d0086e6",
      "40a20c5548a5e9202f69735ecc06c990e6b7c9d2de39f0361e27baeb24cb7c45"),
  # NOTE(review): the second digest below is identical to "base"'s second
  # digest, which looks like a copy-paste error. Verify it before anything
  # relies on element [1] (unused here, so evaluation is unaffected).
  "large":
    ("070c5ed9ed289581e477741d3b34beffa920db8cf590899d6d2c67fba2a198a6",
      "40a20c5548a5e9202f69735ecc06c990e6b7c9d2de39f0361e27baeb24cb7c45"),
  "xlarge":
    ("c1f5ccab661354fc3a79a10fa99af82f0fbf10ec65cb894a3ae0815f17a889ee",
      "de3f8a54174130e0cecdc71583354753d557fcf1f4487331558e2a16ba0cfe05"),
}

## Set up ImageNet-1k labels

In [6]:
# Load the ImageNet-1k class index. Each value's element [0] is the WordNet
# synset ID used as the validation sub-directory name (e.g. "n03000134"), and
# element [1] is the human-readable class name. JSON is UTF-8 by spec, so the
# encoding is explicit instead of depending on the platform locale.
with open("imagenet_class_index.json", "r", encoding="utf-8") as read_file:
    imagenet_labels = json.load(read_file)

# synset ID -> integer class index (used to label validation images).
MAPPING_DICT = {entry[0]: int(label_id) for label_id, entry in imagenet_labels.items()}
# integer class index -> human-readable class name.
LABEL_NAMES = {int(label_id): entry[1] for label_id, entry in imagenet_labels.items()}

# Duplicate synset IDs would silently overwrite each other in MAPPING_DICT.
assert len(MAPPING_DICT) == len(imagenet_labels), (
    "duplicate synset IDs in imagenet_class_index.json"
)

In [7]:
# Collect every image under ./val. Each file's parent directory name is its
# WordNet synset ID, which MAPPING_DICT converts into an integer class label.
all_val_paths = list(paths.list_images("val"))
assert all_val_paths, "no images found under ./val"

# os.path (rather than splitting on "/") respects the host path separator and
# does not assume the root directory is a single path component.
all_val_labels = [
    MAPPING_DICT[os.path.basename(os.path.dirname(image_path))]
    for image_path in all_val_paths
]

all_val_paths[:5], all_val_labels[:5]

(['val/n03000134/ILSVRC2012_val_00009432.JPEG',
  'val/n03000134/ILSVRC2012_val_00018410.JPEG',
  'val/n03000134/ILSVRC2012_val_00043280.JPEG',
  'val/n03000134/ILSVRC2012_val_00041208.JPEG',
  'val/n03000134/ILSVRC2012_val_00014205.JPEG'],
 [489, 489, 489, 489, 489])

## Preprocessing utilities

In [8]:
# Model already has a normalization layer inside.
def load_and_prepare(path, label, image_size=224, crop_fraction=0.875):
    """Read, decode, resize and center-crop one validation image.

    Pixels are left unnormalized (float32 on the 0-255 scale; bicubic
    resampling can overshoot slightly) because, per the note above, the
    model normalizes its inputs internally.

    Args:
        path: Scalar string tensor holding the image file path.
        label: Integer class index; returned unchanged.
        image_size: Side length of the square output crop. The default
            matches IMAGE_SIZE in the constants cell.
        crop_fraction: Fraction of the resized image kept by the center
            crop. The image is first resized to
            round(image_size / crop_fraction) pixels per side (256 for the
            defaults).

    Returns:
        (image, label), where image is a float32 tensor of shape
        (image_size, image_size, 3).
    """
    resize_size = int(round(image_size / crop_fraction))
    offset = (resize_size - image_size) // 2

    image = tf.io.read_file(path)
    # decode_image handles JPEG (the ImageNet val files are .JPEG) as well as
    # PNG/GIF/BMP. expand_animations=False guarantees a rank-3 tensor, so the
    # resize below sees a known rank.
    image = tf.io.decode_image(image, channels=3, expand_animations=False)
    # NOTE(review): this resizes to a square and ignores aspect ratio. Common
    # ImageNet eval pipelines resize the shorter side instead; confirm which
    # protocol the reported accuracies are meant to follow.
    image = tf.image.resize(image, (resize_size, resize_size), method="bicubic")
    # Explicit offsets yield exactly image_size x image_size for any settings.
    # For the defaults this equals central_crop(image, 0.875) (16-px border).
    image = tf.image.crop_to_bounding_box(image, offset, offset, image_size, image_size)
    return image, label

## Prepare `tf.data.Dataset`

In [9]:
# Validation pipeline: (path, label) pairs -> parallel decode/preprocess ->
# fixed-size batches -> prefetch, so input preparation overlaps evaluation.
dataset = (
    tf.data.Dataset.from_tensor_slices((all_val_paths, all_val_labels))
    .map(load_and_prepare, num_parallel_calls=AUTO)
    .batch(BATCH_SIZE)
    .prefetch(AUTO)
)
dataset.element_spec

2022-05-06 06:18:14.992433: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-05-06 06:18:20.470898: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1525] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 38414 MB memory:  -> device: 0, name: A100-SXM4-40GB, pci bus id: 0000:00:04.0, compute capability: 8.0


(TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32, name=None),
 TensorSpec(shape=(None,), dtype=tf.int32, name=None))

## Initialize models and run eval

In [10]:
# Evaluate every ConvNeXt variant on the validation set and write its top-1
# accuracy (percent, 4 decimals) to `<variant>.txt`.
for model_name, config in MODEL_CONFIGS.items():
    # Drop Keras global state left over from the previous variant; otherwise
    # it keeps growing as models are built in a loop.
    keras.backend.clear_session()

    model = convnext.ConvNeXt(**config, include_top=True)

    # BASE_WEIGHTS_PATH is a URL, so concatenate instead of os.path.join
    # (which is meant for local filesystem paths).
    checkpoint_name = f"convnext_{model_name}.h5"
    checkpoint_path = BASE_WEIGHTS_PATH + checkpoint_name
    print(f"Fetching checkpoint from {checkpoint_path}.")

    # Element [0] is the digest checked against the checkpoint downloaded here.
    # Direct indexing raises a clear KeyError for an unknown variant.
    file_hash = WEIGHTS_HASHES[model_name][0]
    weights_path = keras.utils.get_file(
        checkpoint_name,
        checkpoint_path,
        cache_subdir="models",
        file_hash=file_hash,
    )
    model.load_weights(weights_path)
    model.compile(metrics=["accuracy"])
    tb_callback = keras.callbacks.TensorBoard(log_dir=f"logs_{model_name}")

    _, accuracy = model.evaluate(dataset, callbacks=[tb_callback])
    accuracy = round(accuracy * 100, 4)
    # Context manager guarantees the result is flushed and the handle closed,
    # instead of relying on garbage collection of an anonymous file object.
    with open(f"{model_name}.txt", "w") as result_file:
        print(f"{model_name}: {accuracy}%.", file=result_file)

Fetching checkpoint from https://storage.googleapis.com/convnext-tf/keras-applications-temp/convnext/convnext_tiny.h5.
A local file was found, but it seems to be incomplete or outdated because the auto file hash does not match the original value of 8ae6e78ce2933352b1ef4008e6dd2f17bc40771563877d156bc6426c7cf503ff so we will re-download the data.
Downloading data from https://storage.googleapis.com/convnext-tf/keras-applications-temp/convnext/convnext_tiny.h5


2022-05-06 06:18:32.383205: I tensorflow/stream_executor/cuda/cuda_dnn.cc:366] Loaded cuDNN version 8200
2022-05-06 06:18:38.069504: I tensorflow/stream_executor/cuda/cuda_blas.cc:1774] TensorFloat-32 will be used for the matrix multiplication. This will only be logged once.


Fetching checkpoint from https://storage.googleapis.com/convnext-tf/keras-applications-temp/convnext/convnext_small.h5.
A local file was found, but it seems to be incomplete or outdated because the auto file hash does not match the original value of ce1277d8f1ee5a0ef0e171469089c18f5233860ceaf9b168049cb9263fd7483c so we will re-download the data.
Downloading data from https://storage.googleapis.com/convnext-tf/keras-applications-temp/convnext/convnext_small.h5
Fetching checkpoint from https://storage.googleapis.com/convnext-tf/keras-applications-temp/convnext/convnext_base.h5.
A local file was found, but it seems to be incomplete or outdated because the auto file hash does not match the original value of 52cbb006d3dadd03f6e095a8ca1aca47aecdd75acb4bc74bce1f5c695d0086e6 so we will re-download the data.
Downloading data from https://storage.googleapis.com/convnext-tf/keras-applications-temp/convnext/convnext_base.h5
Fetching checkpoint from https://storage.googleapis.com/convnext-tf/keras-