In [None]:
import json
import os

import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_hub as hub
from imutils import paths

In [None]:
# Pick a distribution strategy: use a TPU when one can be resolved and
# initialized, otherwise fall back to MirroredStrategy (all local GPUs, or a
# single device when none are visible).
tpu = None
try:
    # The resolver is expected to raise ValueError when no TPU is configured.
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.TPUStrategy(tpu)
except ValueError:
    strategy = tf.distribute.MirroredStrategy()

print("Number of accelerators: ", strategy.num_replicas_in_sync)

In [None]:
# Let tf.data tune parallelism / prefetch depth dynamically.
AUTO = tf.data.AUTOTUNE

# Global batch size = fixed per-replica batch times the number of replicas,
# so every accelerator processes the same number of examples per step.
PER_REPLICA_BATCH_SIZE = 128
BATCH_SIZE = PER_REPLICA_BATCH_SIZE * strategy.num_replicas_in_sync

In [None]:
# The class-index file is read as {"<index>": [<dir name>, <readable name>], ...}
# (schema inferred from how entries are indexed below — confirm against the file).
with open("imagenet_class_index.json", "r") as class_index_file:
    imagenet_labels = json.load(class_index_file)

# Directory name (as found under val/) -> integer class index.
MAPPING_DICT = {entry[0]: int(idx) for idx, entry in imagenet_labels.items()}
# Integer class index -> human-readable class name.
LABEL_NAMES = {int(idx): entry[1] for idx, entry in imagenet_labels.items()}

In [None]:
# Validation images are expected at <VAL_DIR>/<class dir>/<image file>; the
# class directory name is looked up in MAPPING_DICT to get the integer label.
VAL_DIR = "val"
all_val_paths = list(paths.list_images(VAL_DIR))
assert all_val_paths, f"No images found under {VAL_DIR!r}; check the dataset location."

# Take the first path component below VAL_DIR using OS-aware path handling;
# splitting on a literal "/" breaks with Windows separators or a root like "data/val".
all_val_labels = [
    MAPPING_DICT[os.path.relpath(p, VAL_DIR).split(os.sep)[0]] for p in all_val_paths
]

all_val_paths[:5], all_val_labels[:5]

In [None]:
def load_and_prepare(path, label, image_size=(224, 224)):
    """Read, decode and resize a single image for evaluation.

    Args:
        path: Scalar string tensor holding the image file path.
        label: Integer class label; returned unchanged.
        image_size: Target (height, width) passed to `tf.image.resize`.
            Defaults to (224, 224), matching the model's InputLayer.

    Returns:
        `(image, label)`, where `image` is a float32 tensor of shape
        `(*image_size, 3)` with unnormalized pixel values (normalization is
        left to the model).
    """
    raw_bytes = tf.io.read_file(path)
    # Detect the encoding instead of forcing decode_png: the files are presumably
    # JPEG (ImageNet val), and list_images can also yield PNG/BMP.
    # expand_animations=False guarantees a 3-D tensor so resize sees a known rank.
    image = tf.io.decode_image(raw_bytes, channels=3, expand_animations=False)
    image = tf.image.resize(image, image_size)

    return image, label

In [None]:
# Evaluation pipeline: (path, label) pairs -> decoded/resized images, grouped
# into global batches, with prefetching to overlap input work and compute.
dataset = (
    tf.data.Dataset.from_tensor_slices((all_val_paths, all_val_labels))
    .map(load_and_prepare, num_parallel_calls=AUTO)
    .batch(BATCH_SIZE)
    .prefetch(AUTO)
)

In [None]:
def get_model(model_url="https://tfhub.dev/sayakpaul/vit_s16_classification/1"):
    """Wrap a TF-Hub classification model in a Keras Sequential model.

    The model takes 224x224x3 inputs whose pixel values are expected in
    [0, 255]. The Rescaling layer maps that range to [-1, 1]
    (x / 127.5 - 1) before the hub layer runs.

    Args:
        model_url: TF-Hub handle of the classifier to wrap.

    Returns:
        An uncompiled `tf.keras.Sequential` model.
    """
    layer_stack = [
        tf.keras.layers.InputLayer((224, 224, 3)),
        tf.keras.layers.Rescaling(scale=1.0 / 127.5, offset=-1),  # Scales to [-1, 1].
        hub.KerasLayer(model_url),
    ]
    return tf.keras.Sequential(layer_stack)

In [None]:
def eval_util(model_url, arch):
    """Evaluate one TF-Hub classifier on the module-level `dataset`.

    Args:
        model_url: TF-Hub handle forwarded to `get_model`.
        arch: Short architecture tag; names the TensorBoard log directory
            (`logs_{arch}`).

    Returns:
        Dict of metric name -> scalar value from
        `Model.evaluate(..., return_dict=True)`.
    """
    tb_callback = tf.keras.callbacks.TensorBoard(log_dir=f"logs_{arch}")
    with strategy.scope():
        model = get_model(model_url)
        # Compile under the same strategy scope as model creation so metric
        # variables are distributed like the weights. Use an explicit sparse
        # metric: labels are integer class ids, and this avoids Keras
        # inferring the accuracy variant from tensor shapes.
        model.compile(
            metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy")]
        )
    return model.evaluate(dataset, callbacks=[tb_callback], return_dict=True)

In [None]:
model_urls = [
    "https://tfhub.dev/sayakpaul/vit_s16_classification/1",
    "https://tfhub.dev/sayakpaul/vit_r26_s32_lightaug_classification/1",
]

# Short tags used for logging and TensorBoard directories; paired by position.
archs = ["s16", "r26_s32"]

# zip() silently drops unmatched entries, so fail loudly on a length mismatch.
assert len(model_urls) == len(archs), "model_urls and archs must have the same length"

for model_url, arch in zip(model_urls, archs):
    print(f"Evaluating {arch}")
    eval_util(model_url, arch)