Main notebook for running fine-tuned vit checkpoints for various datasets

REQUIRED: Download index for checkpoints by running following in the terminal:  
    wget https://storage.googleapis.com/vit_models/augreg/index.csv  
Store it in the data folder  

Link to useful colab and github page: "How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers"  
    https://colab.research.google.com/github/google-research/vision_transformer/blob/main/vit_jax_augreg.ipynb  
    https://github.com/google-research/vision_transformer?tab=readme-ov-file  

In [44]:
import pandas as pd, numpy as np
import torch, timm
from torchvision.transforms import v2
import torchvision.datasets as datasets
import tensorflow as tf
from tqdm import tqdm
import time  # For latency measurement
from ptflops import get_model_complexity_info

## Exploring checkpoints

In [2]:
df = pd.read_csv('data/index.csv')

# all columns of index.csv dataframe
index_cols = list(df.columns)

# list of all models in the vit family
model_types = df.name.unique()

# best checkpoint filenames based on pre-train results
best_pretrains = set(
    df.query('ds=="i21k"')
    .groupby('name')
    .apply(lambda df: df.sort_values('final_val').iloc[-1], include_groups=False)
    .filename
)
# Select all finetunes from these models.
best_finetunes = df.loc[df.filename.apply(lambda filename: filename in best_pretrains)]

# all adapted datasets
adapt_datasets = best_finetunes.adapt_ds.unique()

In [3]:
print('Datasets   : ', adapt_datasets)
print('Model types: ', model_types)
print('Index cols : ', index_cols)

Datasets   :  ['imagenet2012' 'cifar100' 'resisc45' 'oxford_iiit_pet' 'kitti']
Model types:  ['Ti/16' 'S/32' 'B/16' 'L/16' 'R50+L/32' 'R26+S/32' 'S/16' 'B/32'
 'R+Ti/16' 'B/8']
Index cols :  ['name', 'ds', 'epochs', 'lr', 'aug', 'wd', 'do', 'sd', 'best_val', 'final_val', 'final_test', 'adapt_ds', 'adapt_lr', 'adapt_steps', 'adapt_resolution', 'adapt_final_val', 'adapt_final_test', 'params', 'infer_samples_per_sec', 'filename', 'adapt_filename']


In [4]:
# helper functions
def cmp_models(datasets, models_list):
    data = {'models': models_list}
    for ds in datasets:
        info = [
            best_finetunes.query(f'name=="{m}" and adapt_ds=="{ds}"')
            .sort_values('adapt_final_val') # I'm not sure whether we should sort by validation result or test result
            .iloc[-1] for m in models_list  # but original colab uses validation result
        ]
        data[f'{ds}-res'] = [int(i.adapt_resolution) for i in info]
        data[f'{ds}-val'] = [round(float(i.adapt_final_val), 5) for i in info]
        data[f'{ds}-test'] = [round(float(i.adapt_final_test), 5) for i in info]
    return pd.DataFrame(data=data)

def get_best_model(adapt_ds, model_type):
    out = (
        best_finetunes.query(f'name=="{model_type}" and adapt_ds=="{adapt_ds}"')
        .sort_values('adapt_final_val').iloc[-1].adapt_filename
    )
    return out

In [5]:
cmp_models(adapt_datasets, ['B/16', 'S/16', 'Ti/16'])

Unnamed: 0,models,imagenet2012-res,imagenet2012-val,imagenet2012-test,cifar100-res,cifar100-val,cifar100-test,resisc45-res,resisc45-val,resisc45-test,oxford_iiit_pet-res,oxford_iiit_pet-val,oxford_iiit_pet-test,kitti-res,kitti-val,kitti-test
0,B/16,384,0.89432,0.85486,224,0.94,0.9408,384,0.9773,0.97508,384,0.9837,0.94711,224,0.86525,0.81294
1,S/16,384,0.87082,0.83728,224,0.922,0.9206,384,0.97222,0.96556,384,0.96739,0.94029,384,0.85714,0.83475
2,Ti/16,384,0.81993,0.7822,224,0.888,0.8801,384,0.96698,0.96143,384,0.9538,0.91142,384,0.85106,0.83122


## Loading model and dataset

In [66]:
device = torch.device(
    "mps" if torch.backends.mps.is_available() else (
    "cuda" if torch.cuda.is_available() else
    "cpu"
))

timm_modelnames = {
    'Ti/16-224': 'vit_tiny_patch16_224',
    'Ti/16-384': 'vit_tiny_patch16_384',
    'S/16-224': 'vit_small_patch16_224',
    'S/16-384': 'vit_small_patch16_384',
    'B/16-224': 'vit_base_patch16_224',
    'B/16-384': 'vit_base_patch16_384'
}

def load_model_and_dataset(adapt_ds, model_type, batch_size):
    """
        Supports only 'cifar100' and 'oxford-iiit-pet' datasets
        because other datasets already show similar results
    """
    model_to_load = get_best_model(adapt_ds, model_type)
    print(f"Loaded checkpoint: {model_to_load}")
    # sample output: Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--cifar100-steps_10k-lr_0.003-res_224
    res = int(model_to_load.split('_')[-1])

    # load dataset
    ds_transform = v2.Compose([
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            v2.Resize((res, res)),
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
            # TODO: just some magic numbers for now, I need to find exact numbers that were used
            # but these are good enough
        ])
    if adapt_ds == 'cifar100':
        dataset = datasets.CIFAR100('data/', train=False, transform=ds_transform, download=True)
    else:
        dataset = datasets.OxfordIIITPet('data/', split='test', transform=ds_transform, download=True)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)


    model = timm.create_model(timm_modelnames[f'{model_type}-{res}'], num_classes=len(dataset.classes))

    # downloading a checkpoint automatically
    # may show an error, but still downloads the checkpoint
    if not tf.io.gfile.exists(f'data/{model_to_load}.npz'):
        tf.io.gfile.copy(f'gs://vit_models/augreg/{model_to_load}.npz', f'data/{model_to_load}.npz')

    timm.models.load_checkpoint(model, f'data/{model_to_load}.npz')

    model.to(device)
    model.eval()

    return model, dataset, dataloader, res

def test(model, dataloader, dataset):
    sm = torch.nn.Softmax(dim=1)
    with torch.no_grad():
        acc, correct = 0, 0
        for features, labels in tqdm(iter(dataloader)):
            features = features.to(device)
            labels = labels.to(device)
            clf = sm(model(features)).argmax(1)
            correct += (clf == labels).sum()
    acc = correct / len(dataset)
    return acc

def eval_perf(model, dataset, res):
    with torch.no_grad():
        # === Measure Latency ===
        print("Measuring latency...")
        start_time = time.time()
        model((next(iter(dataset))[0])[None].to(device))  # Measure latency for a single prediction
        end_time = time.time()
        latency = (end_time - start_time) * 1000  # Convert to milliseconds
        print(f"Latency (inference time for one sample): {latency:.2f} ms")

        # === Measure FLOPs ===
        print("Calculating FLOPs...")
        macs, params = get_model_complexity_info(model, (3, res, res),
                                             as_strings=False,
                                             print_per_layer_stat=False,
                                             verbose=False)
        print(f"GFLOPs: {(macs * 2)//1e9}\nParams: {params}")

In [108]:
batch_size = 32
ti16, dataset, dataloader, res = load_model_and_dataset('oxford_iiit_pet', 'Ti/16', batch_size)
acc_ti16 = test(ti16, dataloader, dataset)
print(f'\nTi16 accuracy {acc_ti16:.4f}')

Loaded checkpoint: Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--oxford_iiit_pet-steps_2k-lr_0.001-res_384


100%|██████████| 115/115 [00:44<00:00,  2.60it/s]


Ti16 accuracy 0.9057





In [107]:
eval_perf(ti16, dataset, res)

Measuring latency...
Latency (inference time for one sample): 12.13 ms
Calculating FLOPs...
GFLOPs: 1.0
Params: 5543716


In [95]:
batch_size = 32
b16, dataset, dataloader, res = load_model_and_dataset('cifar100', 'B/16', batch_size)
acc_b16 = test(b16, dataloader, dataset)
print(f'\nB16 accuracy {acc_b16:.4f}')

Loaded checkpoint: B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--cifar100-steps_10k-lr_0.003-res_224


100%|██████████| 313/313 [01:53<00:00,  2.76it/s]


B16 accuracy 0.9405





In [100]:
eval_perf(b16, dataset, res)

Measuring latency...
Latency (inference time for one sample): 12.48 ms
Calculating FLOPs...
GFLOPs: 24.0
Params: 85875556
