## Setup

In [None]:
!pip install -q git+https://github.com/rwightman/pytorch-image-models

  Building wheel for timm (setup.py) ... [?25l[?25hdone


In [None]:
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

import timm
import torch
import tensorflow as tf
import tensorflow_datasets as tfds 

import pickle
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

In [None]:
# Number of examples per batch produced by the tf.data pipeline.
BATCH_SIZE = 64 # Reduce if ResourceExhaustedError happens
# Shorthand for tf.data's autotuning sentinel (used for prefetching below).
AUTO = tf.data.AUTOTUNE

## Common utilities

In [None]:
def get_normalization_layer(imagenet_stats=False, scale=None, offset=None):
    """Build the Keras layer used to normalize image pixel values.

    Args:
        imagenet_stats: If True, return a `Normalization` layer configured
            with timm's ImageNet mean/std; `scale` and `offset` are ignored.
        scale: Multiplier for a `Rescaling` layer. Only used when `offset`
            is also provided.
        offset: Additive term for a `Rescaling` layer. Only used when `scale`
            is also provided. An offset of 0 is honored.

    Returns:
        A `tf.keras.layers.Normalization` or `tf.keras.layers.Rescaling`
        layer. When neither ImageNet statistics nor a scale/offset pair is
        requested, the layer rescales by 1/255.
    """
    if imagenet_stats:
        # Normalization takes the variance, so square the standard deviation.
        return tf.keras.layers.Normalization(
            mean=np.array(IMAGENET_DEFAULT_MEAN),
            variance=np.array(IMAGENET_DEFAULT_STD) ** 2
        )
    # Compare against None instead of relying on truthiness: a legitimate
    # offset (or scale) of 0 must not silently fall back to the default.
    if scale is not None and offset is not None:
        return tf.keras.layers.Rescaling(scale=scale, offset=offset)
    return tf.keras.layers.Rescaling(scale=1./255)


def preprocess_image(normalization_layer):
    """Return a `(image, label) -> (image, label)` function for `Dataset.map`.

    The image is cast to float32 and passed through `normalization_layer`.
    When that layer is a `tf.keras.layers.Normalization`, the image is first
    divided by 255 before the layer is applied. Labels pass through untouched.
    """
    # Decide once whether the layer needs inputs pre-divided by 255.
    divide_by_255 = isinstance(
        normalization_layer, tf.keras.layers.Normalization
    )

    def f(image, label):
        pixels = tf.cast(image, tf.float32)
        if divide_by_255:
            pixels = pixels / 255.
        return normalization_layer(pixels), label

    return f


def get_dataset(ds_name="imagenet_a", imagenet_stats=False, resize=224,
                scale=None, offset=None):
    """Load the `test` split of a TFDS dataset as a batched, normalized pipeline.

    Args:
        ds_name: Name passed to `tfds.load` (e.g. "imagenet_a", "imagenet_r").
        imagenet_stats: Forwarded to `get_normalization_layer`.
        resize: Side length; every image is resized to `(resize, resize)`.
        scale: Forwarded to `get_normalization_layer`.
        offset: Forwarded to `get_normalization_layer`.

    Returns:
        A `tf.data.Dataset` yielding `(image_batch, label_batch)` pairs of
        up to `BATCH_SIZE` examples, with images resized and normalized.
    """
    # get_normalization_layer already decides between ImageNet statistics,
    # a custom scale/offset and the default rescaling, so delegate to it
    # instead of repeating the same branching here.
    norm_layer = get_normalization_layer(imagenet_stats, scale, offset)

    dataset = tfds.load(ds_name, split="test", as_supervised=True)
    return (
        dataset
        # Images must share a size before they can be batched.
        .map(lambda x, y: (tf.image.resize(x, (resize, resize)), y),
             num_parallel_calls=AUTO)
        .batch(BATCH_SIZE)
        # num_parallel_calls expects a parallelism level (or AUTOTUNE),
        # not a boolean.
        .map(preprocess_image(norm_layer), num_parallel_calls=AUTO)
        .prefetch(AUTO)
    )
    

In [None]:
# Sanity check: with default arguments a batch should have shape
# (BATCH_SIZE, 224, 224, 3) and the pixel range should be printed as max, min.
ds = get_dataset()
image_batch, label_batch = next(iter(ds))
print(image_batch.shape, label_batch.shape)
print(tf.reduce_max(image_batch), tf.reduce_min(image_batch))

(64, 224, 224, 3) (64,)
tf.Tensor(1.0, shape=(), dtype=float32) tf.Tensor(0.0, shape=(), dtype=float32)


In [None]:
def eval_single_model(dataset, model):
    """Compute top-1 and top-5 accuracy of a PyTorch model over a tf.data dataset.

    Args:
        dataset: A `tf.data.Dataset` yielding `(image_batch, label_batch)`
            with channels-last images and integer class labels.
        model: A `torch.nn.Module` that maps channels-first image batches to
            per-class logits.

    Returns:
        A `(top_1, top_5)` tuple with the accuracies computed over every
        example in `dataset`.
    """
    top_1 = tf.keras.metrics.SparseCategoricalAccuracy()
    top_5 = tf.keras.metrics.SparseTopKCategoricalAccuracy(k=5)

    # Use the GPU when one is present, otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()

    for image_batch, label_batch in dataset.as_numpy_iterator():
        with torch.no_grad():
            images = torch.tensor(image_batch, dtype=torch.float32, device=device)
            # NHWC (TensorFlow layout) -> NCHW (PyTorch layout).
            images = images.permute(0, 3, 1, 2)
            logits = model(images).cpu().numpy()

        # Keras metrics are stateful: calling the metric returns the running
        # value over everything seen so far, so averaging those per-batch
        # return values would not give the dataset accuracy. Accumulate with
        # update_state() and read the final value once at the end.
        top_1.update_state(label_batch, logits)
        top_5.update_state(label_batch, logits)

    return top_1.result().numpy(), top_5.result().numpy()

## AugReg model pre-trained on ImageNet-1k

Reference: https://github.com/google-research/vision_transformer/blob/main/vit_jax_augreg.ipynb

In [None]:
# This checkpoint yields the highest validation accuracy on ImageNet-1k
# with AugReg. Score: 82.7109%.
filename = "B_16-i1k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.1-sd_0.1"

# Build the architecture only (pretrained=False); weights come from the .npz below.
vit_model = timm.create_model('vit_base_patch16_224', num_classes=1000, pretrained=False)

# Non-default checkpoints need to be loaded from local files.
# Copy the checkpoint from the GCS bucket only if it is not already present locally.
if not tf.io.gfile.exists(f'{filename}.npz'):
    tf.io.gfile.copy(f'gs://vit_models/augreg/{filename}.npz', f'{filename}.npz')
timm.models.load_checkpoint(vit_model, f'{filename}.npz')

### ImageNet-A

In [None]:
# Rescale pixels as x / 127.5 - 1, then print the max and min of one batch
# to confirm the resulting value range.
ds = get_dataset(scale=1./127.5, offset=-1)
image_batch, label_batch = next(iter(ds))
print(tf.reduce_max(image_batch), tf.reduce_min(image_batch))

tf.Tensor(1.0, shape=(), dtype=float32) tf.Tensor(-1.0, shape=(), dtype=float32)


In [None]:
# Evaluate the currently loaded model on `ds` and keep the scores keyed by
# the timm architecture name.
mean_top_1, mean_top_5 = eval_single_model(ds, vit_model)

top_1_accs = {"vit_base_patch16_224": mean_top_1}
top_5_accs = {"vit_base_patch16_224": mean_top_5}

top_1_accs, top_5_accs

  del sys.path[0]


({'vit_base_patch16_224': 0.08630994}, {'vit_base_patch16_224': 0.23717582})

### ImageNet-R

In [None]:
# Same x / 127.5 - 1 rescaling, this time on ImageNet-R; print the max and
# min of one batch to confirm the value range.
ds = get_dataset(ds_name="imagenet_r", scale=1./127.5, offset=-1)
image_batch, label_batch = next(iter(ds))
print(tf.reduce_max(image_batch), tf.reduce_min(image_batch))

tf.Tensor(1.0, shape=(), dtype=float32) tf.Tensor(-1.0, shape=(), dtype=float32)


In [None]:
# Evaluate the currently loaded model on `ds` and keep the scores keyed by
# the timm architecture name.
mean_top_1, mean_top_5 = eval_single_model(ds, vit_model)

top_1_accs = {"vit_base_patch16_224": mean_top_1}
top_5_accs = {"vit_base_patch16_224": mean_top_5}

top_1_accs, top_5_accs

({'vit_base_patch16_224': 0.28213835}, {'vit_base_patch16_224': 0.4180957})

## Smaller model comparable to a ResNet50 w.r.t. complexity

Score: 79.086%

In [None]:
filename = "S_16-i1k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0"

# Build the architecture only (pretrained=False); weights come from the .npz below.
vit_model = timm.create_model('vit_small_patch16_224', num_classes=1000, pretrained=False)

# Non-default checkpoints need to be loaded from local files.
# Copy the checkpoint from the GCS bucket only if it is not already present locally.
if not tf.io.gfile.exists(f'{filename}.npz'):
    tf.io.gfile.copy(f'gs://vit_models/augreg/{filename}.npz', f'{filename}.npz')
timm.models.load_checkpoint(vit_model, f'{filename}.npz')

### ImageNet-A

In [None]:
# Rescale pixels as x / 127.5 - 1, then print the max and min of one batch
# to confirm the resulting value range.
ds = get_dataset(scale=1./127.5, offset=-1)
image_batch, label_batch = next(iter(ds))
print(tf.reduce_max(image_batch), tf.reduce_min(image_batch))

[1mDownloading and preparing dataset imagenet_a/0.1.0 (download: 655.70 MiB, generated: 650.87 MiB, total: 1.28 GiB) to /root/tensorflow_datasets/imagenet_a/0.1.0...[0m


Dl Completed...: 0 url [00:00, ? url/s]

Dl Size...: 0 MiB [00:00, ? MiB/s]

Extraction completed...: 0 file [00:00, ? file/s]






0 examples [00:00, ? examples/s]

Shuffling and writing examples to /root/tensorflow_datasets/imagenet_a/0.1.0.incompleteT2LZZV/imagenet_a-test.tfrecord


  0%|          | 0/7500 [00:00<?, ? examples/s]

[1mDataset imagenet_a downloaded and prepared to /root/tensorflow_datasets/imagenet_a/0.1.0. Subsequent calls will reuse this data.[0m
tf.Tensor(1.0, shape=(), dtype=float32) tf.Tensor(-1.0, shape=(), dtype=float32)


In [None]:
top_1_accs = {}
top_5_accs = {}

mean_top_1, mean_top_5 = eval_single_model(ds, vit_model)

# The model evaluated in this section is the small ViT, so key the results
# by its architecture name rather than the base model's.
top_1_accs.update({"vit_small_patch16_224": mean_top_1})
top_5_accs.update({"vit_small_patch16_224": mean_top_5})

top_1_accs, top_5_accs

  del sys.path[0]


({'vit_base_patch16_224': 0.0639517}, {'vit_base_patch16_224': 0.1938934})

### ImageNet-R

In [None]:
# Same x / 127.5 - 1 rescaling, this time on ImageNet-R; print the max and
# min of one batch to confirm the value range.
ds = get_dataset(ds_name="imagenet_r", scale=1./127.5, offset=-1)
image_batch, label_batch = next(iter(ds))
print(tf.reduce_max(image_batch), tf.reduce_min(image_batch))

[1mDownloading and preparing dataset imagenet_r/0.1.0 (download: 2.04 GiB, generated: 2.03 GiB, total: 4.07 GiB) to /root/tensorflow_datasets/imagenet_r/0.1.0...[0m


Dl Completed...: 0 url [00:00, ? url/s]

Dl Size...: 0 MiB [00:00, ? MiB/s]

Extraction completed...: 0 file [00:00, ? file/s]






0 examples [00:00, ? examples/s]

Shuffling and writing examples to /root/tensorflow_datasets/imagenet_r/0.1.0.incompleteKX3JFC/imagenet_r-test.tfrecord


  0%|          | 0/30000 [00:00<?, ? examples/s]

[1mDataset imagenet_r downloaded and prepared to /root/tensorflow_datasets/imagenet_r/0.1.0. Subsequent calls will reuse this data.[0m
tf.Tensor(1.0, shape=(), dtype=float32) tf.Tensor(-1.0, shape=(), dtype=float32)


In [None]:
top_1_accs = {}
top_5_accs = {}

mean_top_1, mean_top_5 = eval_single_model(ds, vit_model)

# The model evaluated in this section is the small ViT, so key the results
# by its architecture name rather than the base model's.
top_1_accs.update({"vit_small_patch16_224": mean_top_1})
top_5_accs.update({"vit_small_patch16_224": mean_top_5})

top_1_accs, top_5_accs

({'vit_base_patch16_224': 0.2611397}, {'vit_base_patch16_224': 0.39939818})

## Pretraining effects with AugReg

Note that better models are available, which are essentially deeper than B/16. To allow a fair comparison, however, we stick to B/16 and an adaptation resolution of 224.

In [None]:
filename = "B_16-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224" # 84.018%
# B_16-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-do_0.0-sd_0.0

# Build the architecture only (pretrained=False); weights come from the .npz below.
vit_model = timm.create_model('vit_base_patch16_224', num_classes=1000, pretrained=False)

# Non-default checkpoints need to be loaded from local files.
# Copy the checkpoint from the GCS bucket only if it is not already present locally.
if not tf.io.gfile.exists(f'{filename}.npz'):
    tf.io.gfile.copy(f'gs://vit_models/augreg/{filename}.npz', f'{filename}.npz')
timm.models.load_checkpoint(vit_model, f'{filename}.npz')

### ImageNet-A

In [None]:
# ImageNet-A with pixels rescaled as x / 127.5 - 1.
ds = get_dataset(scale=1./127.5, offset=-1)

In [None]:
# Evaluate the currently loaded model on `ds` and keep the scores keyed by
# the timm architecture name.
mean_top_1, mean_top_5 = eval_single_model(ds, vit_model)

top_1_accs = {"vit_base_patch16_224": mean_top_1}
top_5_accs = {"vit_base_patch16_224": mean_top_5}

top_1_accs, top_5_accs

({'vit_base_patch16_224': 0.21746947}, {'vit_base_patch16_224': 0.46034816})

### ImageNet-R

In [None]:
# ImageNet-R with pixels rescaled as x / 127.5 - 1.
ds = get_dataset(ds_name="imagenet_r", scale=1./127.5, offset=-1)

In [None]:
# Evaluate the currently loaded model on `ds` and keep the scores keyed by
# the timm architecture name.
mean_top_1, mean_top_5 = eval_single_model(ds, vit_model)

top_1_accs = {"vit_base_patch16_224": mean_top_1}
top_5_accs = {"vit_base_patch16_224": mean_top_5}

top_1_accs, top_5_accs

({'vit_base_patch16_224': 0.41815233}, {'vit_base_patch16_224': 0.5837572})