## Imports

Suppress TensorFlow warnings.

In [None]:
# Copied from:
# https://weepingfish.github.io/2020/07/22/0722-suppress-tensorflow-warnings/

# Filter tensorflow version warnings
import os

# https://stackoverflow.com/questions/40426502/is-there-a-way-to-suppress-the-messages-tensorflow-prints/40426709
# This has to be set BEFORE tensorflow is imported, which is why the import of
# tensorflow comes after it. "3" is the most restrictive level (the answer
# linked above describes what each level filters).
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"  # or any {'0', '1', '2'}
import warnings

# https://stackoverflow.com/questions/15777951/how-to-suppress-pandas-future-warning
# `Warning` is the base class of every warning category (FutureWarning
# included), so this single filter silences all Python-level warnings.
warnings.simplefilter(action="ignore", category=Warning)
import logging

import tensorflow as tf

# Only let ERROR (and above) records through TensorFlow's Python logger and
# turn off AutoGraph's verbose output. (The logger level used to be set to
# "INFO" first and then immediately overwritten; that dead call is gone.)
tf.get_logger().setLevel(logging.ERROR)
tf.autograph.set_verbosity(0)

In [None]:
from tensorflow import keras

from torchvision.datasets import ImageFolder
from torchvision import transforms
from torch.utils.data import DataLoader

from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

## Constants

In [None]:
# Evaluation settings.
# - BATCH_SIZE: lower this if you run out of memory (OOM).
# - IMAGE_SIZE: evaluation resolution. Run once with 224; afterwards set it to
#   384 to evaluate the 384-resolution models. It also selects which model
#   directories are picked up and which preprocessing chain is used.
# - TF_MODEL_ROOT: bucket holding the exported TF Swin models.
BATCH_SIZE = 256
IMAGE_SIZE = 224
TF_MODEL_ROOT = "gs://swin-tf"

## Swin models 

In [None]:
# Keep only the model directories under TF_MODEL_ROOT whose name mentions the
# current IMAGE_SIZE and contains none of the excluded tags. "fe" presumably
# marks feature-extractor exports and "22k" ImageNet-22k checkpoints (the
# results log is named *_in1k.csv) — confirm against the bucket's naming.
model_paths = [
    entry
    for entry in tf.io.gfile.listdir(TF_MODEL_ROOT)
    if str(IMAGE_SIZE) in entry
    and not any(tag in entry for tag in ("fe", "22k"))
]
print(model_paths)

## Image loader

To get an apples-to-apples comparison with the original PyTorch models, we must apply exactly the same evaluation transformations (resizing, cropping, and normalization) that the original evaluation used.

In [None]:
# Transformations from:
# (1) https://github.com/microsoft/Swin-Transformer
# (2) https://github.com/microsoft/Swin-Transformer/tree/main/data
#
# Both branches resample bicubically. The mode is spelled with torchvision's
# InterpolationMode enum rather than the legacy PIL integer code 3, whose use
# as `interpolation=` is deprecated by torchvision.

if IMAGE_SIZE == 224:
    # Resize the shorter image side to 256 (the 256/224 ratio), then take a
    # 224x224 center crop.
    size = int((256 / 224) * IMAGE_SIZE)
    transform_chain = transforms.Compose(
        [
            transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.CenterCrop(IMAGE_SIZE),
            transforms.ToTensor(),
            transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
        ]
    )
else:
    # Any other resolution: resize straight to IMAGE_SIZE x IMAGE_SIZE, no crop.
    transform_chain = transforms.Compose(
        [
            transforms.Resize(
                (IMAGE_SIZE, IMAGE_SIZE),
                interpolation=transforms.InterpolationMode.BICUBIC,
            ),
            transforms.ToTensor(),
            transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
        ]
    )

In [None]:
# Validation images are read from the local "val" directory (ImageFolder
# layout: one sub-directory per class) and preprocessed by transform_chain.
dataset = ImageFolder(root="val", transform=transform_chain)

# `shuffle` stays at its default (off), so batches follow dataset order.
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    num_workers=6,
)

# Smoke test: pull one batch and print the image tensor's shape
# (expected to be (BATCH_SIZE, C, IMAGE_SIZE, IMAGE_SIZE) — channels-first).
batch = next(iter(dataloader))
print(batch[0].shape)

## Run evaluation

In [None]:
def get_model(model_url):
    """Load the Keras model stored at `model_url` and return it.

    `model_url` is passed straight to `keras.models.load_model`; in this
    notebook it is a model directory under TF_MODEL_ROOT.
    """
    return keras.models.load_model(model_url)

In [None]:
# Copied and modified from:
# https://github.com/sebastian-sz/resnet-rs-keras/blob/main/imagenet_evaluation/main.py

log_file = f"swin_{IMAGE_SIZE}_in1k.csv"

# Write the CSV header only when the log does not exist yet, so re-running
# this cell appends new rows instead of discarding earlier results.
if not os.path.exists(log_file):
    with open(log_file, "w") as f:
        f.write("model_name,top1_acc(%),top5_acc(%)\n")


def _evaluate_model(model, loader):
    """Return the (top-1, top-5) accuracy of `model` over all of `loader`.

    `loader` yields `(images, labels)` torch tensors. The images are
    transposed from channels-first (PyTorch) to channels-last before being
    fed to the TF model. Both accuracies are the metrics' final `result()`
    values (fractions, not percentages).
    """
    top1 = tf.keras.metrics.SparseTopKCategoricalAccuracy(k=1, name="top1")
    top5 = tf.keras.metrics.SparseTopKCategoricalAccuracy(k=5, name="top5")
    # len(loader) counts batches including a final partial one. Together with
    # 1-based steps, the bar reaches its target exactly on the last batch.
    progbar = tf.keras.utils.Progbar(target=len(loader))

    for step, (images, y_true) in enumerate(loader, start=1):
        images = images.numpy().transpose(0, 2, 3, 1)
        y_true = y_true.numpy()
        # `predict` is meant for whole datasets: called once per batch it
        # re-builds its input pipeline and prints its own progress bar.
        y_pred = model.predict_on_batch(images)

        top1.update_state(y_true=y_true, y_pred=y_pred)
        top5.update_state(y_true=y_true, y_pred=y_pred)

        progbar.update(
            step, [("top1", top1.result().numpy()), ("top5", top5.result().numpy())]
        )

    return top1.result().numpy(), top5.result().numpy()


for path in model_paths:
    # Directory entries may carry a trailing "/". Use one cleaned name for
    # both the load URL and the CSV row.
    model_name = path.strip("/")
    print(f"Evaluating {path}.")
    model = get_model(f"{TF_MODEL_ROOT}/{model_name}")

    top1_acc, top5_acc = _evaluate_model(model, dataloader)

    print()
    print(f"TOP1: {top1_acc}.  TOP5: {top5_acc}")

    top_1 = top1_acc * 100.0
    top_5 = top5_acc * 100.0
    with open(log_file, "a") as f:
        f.write(f"{model_name},{top_1:.3f},{top_5:.3f}\n")

    # Release this model and Keras' global state before loading the next
    # one, so memory does not keep growing across many models.
    del model
    tf.keras.backend.clear_session()

In [None]:
# Power off this machine as soon as this cell runs (IPython shell escape to
# `sudo shutdown now`). Skip or delete this cell when running anywhere you do
# not want shut down.
!sudo shutdown now