## Imports

In [1]:
import sys

sys.path.append("..")

import convnext

In [2]:
from tensorflow import keras
import tensorflow as tf

from imutils import paths
import json
import os

## Constants

In [3]:
AUTO = tf.data.AUTOTUNE
BATCH_SIZE = 256
IMAGE_SIZE = 224
BASE_WEIGHTS_PATH = "https://storage.googleapis.com/convnext-tf/keras-applications/convnext/"

In [4]:
MODEL_CONFIGS = {
    "tiny": {
        "depths": [3, 3, 9, 3],
        "projection_dims": [96, 192, 384, 768],
        "default_size": 224,
    },
    "small": {
        "depths": [3, 3, 27, 3],
        "projection_dims": [96, 192, 384, 768],
        "default_size": 224,
    },
    "base": {
        "depths": [3, 3, 27, 3],
        "projection_dims": [128, 256, 512, 1024],
        "default_size": 224,
    },
    "large": {
        "depths": [3, 3, 27, 3],
        "projection_dims": [192, 384, 768, 1536],
        "default_size": 224,
    },
    "xlarge": {
        "depths": [3, 3, 27, 3],
        "projection_dims": [256, 512, 1024, 2048],
        "default_size": 224,
    },
}

In [5]:
WEIGHTS_HASHES = {
  "tiny":
    ("dec324e40ebe943afc7b75b72484646eeb092c04bb079df35911d7080364f9a8",
      "4d4f0e079db2cc0e627b55f7d0d76c367145d14f2c90674415373457cd822346"),
  "small":
    ("f964ea5cd5618a1e64902a74ca5ccff3797a4fa5dba11a14f2c4d1a562b72f08",
      "fd8f0ac74faa4e364d7cb5b2d32af9ae35b54ce5e80525b5beb7b7571320065a"),
  "base":
    ("d30e0c509f4e1abe2784d33765d4391ce8fbff259b0bd79f4a63684b20db87d2",
      "736f7a96cd933ee568611e29f334737fb9aebaaea021ea7adfe4d2f5cbb4a9aa"),
  "large":
    ("8a304c66deb782b0d59837bc13127068901adaaa280cfac604d3341aaf44b2cf",
      "b02b623b3c28586423e6be4aa214e2f5619280b97b4ef6b35ffb686e83235f01"),
  "xlarge":
    ("da65d1294d386c71aebd81bc2520b8d42f7f60eee4414806c60730cd63eb15cb",
      "2bfbf5f0c2b3f004f1c32e9a76661e11a9ac49014ed2a68a49ecd0cd6c88d377"),
}

## Set up ImageNet-1k labels

In [6]:
with open("imagenet_class_index.json", "r") as read_file:
    imagenet_labels = json.load(read_file)

MAPPING_DICT = {}
LABEL_NAMES = {}
for label_id in list(imagenet_labels.keys()):
    MAPPING_DICT[imagenet_labels[label_id][0]] = int(label_id)
    LABEL_NAMES[int(label_id)] = imagenet_labels[label_id][1]

In [7]:
all_val_paths = list(paths.list_images("val"))
all_val_labels = [MAPPING_DICT[x.split("/")[1]] for x in all_val_paths]

all_val_paths[:5], all_val_labels[:5]

(['val/n03000134/ILSVRC2012_val_00009432.JPEG',
  'val/n03000134/ILSVRC2012_val_00018410.JPEG',
  'val/n03000134/ILSVRC2012_val_00043280.JPEG',
  'val/n03000134/ILSVRC2012_val_00041208.JPEG',
  'val/n03000134/ILSVRC2012_val_00014205.JPEG'],
 [489, 489, 489, 489, 489])

## Preprocessing utilities

In [8]:
# Model already has a normalization layer inside.
def load_and_prepare(path, label):
    image = tf.io.read_file(path)
    image = tf.image.decode_png(image, channels=3)
    image = tf.image.resize(image, (256, 256), method="bicubic")
    image = tf.image.central_crop(image, 0.875)
    return image, label

## Prepare `tf.data.Dataset`

In [9]:
dataset = tf.data.Dataset.from_tensor_slices((all_val_paths, all_val_labels))
dataset = dataset.map(load_and_prepare, num_parallel_calls=AUTO).batch(BATCH_SIZE)
dataset = dataset.prefetch(AUTO)
dataset.element_spec

2022-04-17 11:39:00.710234: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-04-17 11:39:01.283057: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1525] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 38414 MB memory:  -> device: 0, name: A100-SXM4-40GB, pci bus id: 0000:00:04.0, compute capability: 8.0


(TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32, name=None),
 TensorSpec(shape=(None,), dtype=tf.int32, name=None))

## Initialize models and run eval

In [11]:
for model_name in MODEL_CONFIGS:
    config = MODEL_CONFIGS.get(model_name)
    model = convnext.ConvNeXt(
        **config,
        include_top=True
    )
    checkpoint_path = os.path.join(BASE_WEIGHTS_PATH, f"convnext_{model_name}.h5")
    print(f"Fetching checkpoint from {checkpoint_path}.")
    
    file_hash = WEIGHTS_HASHES.get(model_name)[0]
    weights_path = keras.utils.get_file(
        f"convnext_{model_name}.h5",
        checkpoint_path,
        cache_subdir="models",
        file_hash=file_hash
    )
    model.load_weights(weights_path)
    model.compile(metrics=["accuracy"])
    tb_callback = tf.keras.callbacks.TensorBoard(log_dir=f"logs_{model_name}")
    
    _, accuracy = model.evaluate(dataset, callbacks=[tb_callback])
    accuracy = round(accuracy * 100, 4)
    print(f"{model_name}: {accuracy}%.", file=open(f"{model_name}.txt", "w"))

Fetching checkpoint from https://storage.googleapis.com/convnext-tf/keras-applications/convnext/convnext_tiny.h5.
Downloading data from https://storage.googleapis.com/convnext-tf/keras-applications/convnext/convnext_tiny.h5


2022-04-17 11:39:21.375490: I tensorflow/stream_executor/cuda/cuda_dnn.cc:366] Loaded cuDNN version 8200
2022-04-17 11:39:27.167923: I tensorflow/stream_executor/cuda/cuda_blas.cc:1774] TensorFloat-32 will be used for the matrix multiplication. This will only be logged once.


Fetching checkpoint from https://storage.googleapis.com/convnext-tf/keras-applications/convnext/convnext_small.h5.
Downloading data from https://storage.googleapis.com/convnext-tf/keras-applications/convnext/convnext_small.h5
Fetching checkpoint from https://storage.googleapis.com/convnext-tf/keras-applications/convnext/convnext_base.h5.
Downloading data from https://storage.googleapis.com/convnext-tf/keras-applications/convnext/convnext_base.h5
Fetching checkpoint from https://storage.googleapis.com/convnext-tf/keras-applications/convnext/convnext_large.h5.
Downloading data from https://storage.googleapis.com/convnext-tf/keras-applications/convnext/convnext_large.h5
Fetching checkpoint from https://storage.googleapis.com/convnext-tf/keras-applications/convnext/convnext_xlarge.h5.
Downloading data from https://storage.googleapis.com/convnext-tf/keras-applications/convnext/convnext_xlarge.h5
