# Selecting a pre-trained checkpoint

Google has released 759 pre-trained checkpoints (trained on different ImageNet datasets).

These pre-trained checkpoints are then fine-tuned on downstream tasks (classification on other datasets, e.g. `imagenet2012`, `cifar100`, `resisc45`, `oxford_iiit_pet`, `kitti`). Google has released 52K fine-tuned checkpoints as well.

To select the right pre-trained checkpoint for experimentation, we shortlist in the following order:

1. Best pre-training dataset (we restrict ourselves to ImageNet 21K, for reasons explained below).
2. Checkpoints that give the best final validation accuracy on pre-training.
3. Model(s) that can be fine-tuned in a feasible amount of time.

In [1]:
import pandas as pd
import tensorflow as tf

pd.options.display.max_colwidth = None



### Load database

We load the full index, which lists all 52K checkpoints.

In [2]:
with tf.io.gfile.GFile('gs://vit_models/augreg/index.csv') as f:
    df = pd.read_csv(f)

df.head()



Unnamed: 0,name,ds,epochs,lr,aug,wd,do,sd,best_val,final_val,...,adapt_ds,adapt_lr,adapt_steps,adapt_resolution,adapt_final_val,adapt_final_test,params,infer_samples_per_sec,filename,adapt_filename
0,Ti/16,i1k,300.0,0.001,light0,0.03,0.1,0.1,0.702544,0.702232,...,imagenet2012,0.03,20000,384,0.755698,0.72874,5790000.0,609.58,Ti_16-i1k-300ep-lr_0.001-aug_light0-wd_0.03-do_0.1-sd_0.1,Ti_16-i1k-300ep-lr_0.001-aug_light0-wd_0.03-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.03-res_384
1,Ti/16,i1k,300.0,0.001,light0,0.03,0.1,0.1,0.702544,0.702232,...,imagenet2012,0.01,20000,384,0.754605,0.72412,5790000.0,609.58,Ti_16-i1k-300ep-lr_0.001-aug_light0-wd_0.03-do_0.1-sd_0.1,Ti_16-i1k-300ep-lr_0.001-aug_light0-wd_0.03-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384
2,Ti/16,i1k,300.0,0.001,light0,0.03,0.1,0.1,0.702544,0.702232,...,cifar100,0.03,10000,384,0.836,0.8338,5790000.0,609.58,Ti_16-i1k-300ep-lr_0.001-aug_light0-wd_0.03-do_0.1-sd_0.1,Ti_16-i1k-300ep-lr_0.001-aug_light0-wd_0.03-do_0.1-sd_0.1--cifar100-steps_10k-lr_0.03-res_384
3,Ti/16,i1k,300.0,0.001,light0,0.03,0.1,0.1,0.702544,0.702232,...,cifar100,0.01,10000,384,0.835,0.8304,5790000.0,609.58,Ti_16-i1k-300ep-lr_0.001-aug_light0-wd_0.03-do_0.1-sd_0.1,Ti_16-i1k-300ep-lr_0.001-aug_light0-wd_0.03-do_0.1-sd_0.1--cifar100-steps_10k-lr_0.01-res_384
4,Ti/16,i1k,300.0,0.001,light0,0.03,0.1,0.1,0.702544,0.702232,...,cifar100,0.003,10000,384,0.8,0.7962,5790000.0,609.58,Ti_16-i1k-300ep-lr_0.001-aug_light0-wd_0.03-do_0.1-sd_0.1,Ti_16-i1k-300ep-lr_0.001-aug_light0-wd_0.03-do_0.1-sd_0.1--cifar100-steps_10k-lr_0.003-res_384


### Different (pre-)training configurations

Below we look at the different pre-training configurations. Each configuration determines the `filename` attribute of its checkpoint.

The format of the `filename` is as below:

`ViT model type`-`pre-training dataset`-`number of pre-training epochs`-`learning rate`-`amount of data augmentation`-`weight decay (wd)`-`dropout (do)`-`stochastic depth (sd)`

And an example filename is:

`S_32-i21k-300ep-lr_0.001-aug_none-wd_0.1-do_0.0-sd_0.0`
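The naming scheme above can be unpacked programmatically. The following is a minimal sketch, assuming every pre-training filename has exactly the eight hyphen-separated fields shown in the example; the helper name `parse_pretrain_filename` is hypothetical and not part of any library.

```python
def parse_pretrain_filename(filename: str) -> dict:
    """Split a pre-training filename into its labeled fields.

    Assumption: the filename follows the eight-field pattern shown above.
    A fine-tuned filename carries an extra "--<fine-tuning part>" suffix,
    which is dropped here before parsing.
    """
    pretrain_part = filename.split("--")[0]
    name, ds, epochs, lr, aug, wd, do, sd = pretrain_part.split("-")
    return {
        "name": name,                          # e.g. "S_32" ("S/32" in the index)
        "ds": ds,                              # pre-training dataset, e.g. "i21k"
        "epochs": int(epochs[:-len("ep")]),    # "300ep" -> 300
        "lr": float(lr.split("_", 1)[1]),      # "lr_0.001" -> 0.001
        "aug": aug.split("_", 1)[1],           # "aug_none" -> "none"
        "wd": float(wd.split("_", 1)[1]),      # weight decay
        "do": float(do.split("_", 1)[1]),      # dropout
        "sd": float(sd.split("_", 1)[1]),      # stochastic depth
    }


example = "S_32-i21k-300ep-lr_0.001-aug_none-wd_0.1-do_0.0-sd_0.0"
print(parse_pretrain_filename(example))
```

Note that the model type uses underscores in the filename (`S_32`) where the `name` column of the index uses a slash or plus sign (`S/32`, `R+Ti/16`).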

### Some statistics

Below are some numbers to put things in perspective.

In [3]:
print(f"Available model types: {df.name.unique()}")

print(f"Number of pre-trained checkpoints: {len(df.filename.unique())}")

print(f"Fine tuning datasets: {df.adapt_ds.unique()}")

Available model types: ['Ti/16' 'S/32' 'B/16' 'L/16' 'R50+L/32' 'R26+S/32' 'S/16' 'B/32'
 'R+Ti/16' 'B/8']
Number of pre-trained checkpoints: 759
Fine tuning datasets: ['imagenet2012' 'cifar100' 'resisc45' 'oxford_iiit_pet' 'kitti']


### Shortlisting checkpoints

Instead of directly selecting the parent checkpoint of the best fine-tuned checkpoint for **CIFAR 100**, we first shortlist checkpoints by pre-training dataset.

The paper (section 4.5) notes that good performance can be obtained by simply choosing the model with the best final pre-training validation accuracy (`final_val` column).

Pre-training on `i21k` (ImageNet 21K) gives the best performance in almost all cases (figure 6, table 5).

We therefore restrict our search to `i21k` pre-trained checkpoints, since they appear to generalize better to downstream tasks, and keep the checkpoint with the highest `final_val` for each model type.

This reduces our pre-trained checkpoints from 759 to just 10, one per model type.

In [4]:
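# For each model type, keep the i21k checkpoint with the highest final_val.
best_filenames = set(
    df.query('ds=="i21k"')
    .groupby('name')
    .apply(lambda group: group.sort_values('final_val').iloc[-1])
    .filename
)

# Keep every row (one per fine-tuning run) that belongs to a shortlisted checkpoint.
best_df = df.loc[df.filename.isin(best_filenames)]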

pd.DataFrame(best_df.filename.unique())

Unnamed: 0,0
0,R26_S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0
1,S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0
2,B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0
3,B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0
4,L_16-i21k-300ep-lr_0.001-aug_strong1-wd_0.1-do_0.0-sd_0.0
5,R50_L_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1
6,Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0
7,S_32-i21k-300ep-lr_0.001-aug_none-wd_0.1-do_0.0-sd_0.0
8,R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0
9,B_8-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0
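The per-model-type selection can also be written with `idxmax`, which avoids sorting each group. This is a minimal sketch on a toy table; the model names, filenames and accuracies below are made up for illustration and do not come from the index.

```python
import pandas as pd

# Hypothetical stand-in for the i21k rows of the index.
toy = pd.DataFrame({
    "name": ["A", "A", "B", "B"],
    "filename": ["A-1", "A-2", "B-1", "B-2"],
    "final_val": [0.40, 0.45, 0.50, 0.48],
})

# Row label of the highest final_val within each model type.
best_idx = toy.groupby("name")["final_val"].idxmax()

# Filenames of the selected rows, one per model type.
toy_best_filenames = set(toy.loc[best_idx, "filename"])
print(toy_best_filenames)
```

On ties, `idxmax` returns the first matching row, whereas sorting and taking the last row may pick a different one. In the real index the tied rows of one checkpoint share the same `filename` (one row per fine-tuning run), so the resulting set should be the same in that case.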


### Display performances on CIFAR 100

We have already shortlisted the best pre-trained checkpoint for each model type on the basis of final validation accuracy on ImageNet 21K pre-training.

For the sake of visualization, below we print the performance of each model type on the **CIFAR 100** downstream task.

For example, the **R+Ti/16** model type has a test accuracy (i.e. `adapt_final_test`) of 85.95% on **CIFAR 100** at a fine-tuning resolution of 224 pixels.

In [5]:
best_df.query('adapt_ds=="cifar100"').groupby(['name', 'adapt_resolution']).apply(
    lambda df: df.sort_values('adapt_final_test').iloc[-1]
)[[
   # Columns from upstream
   'name', 'params', 'ds', 'filename',
   # Columns from downstream
   'adapt_resolution', 'infer_samples_per_sec','adapt_ds', 'adapt_final_test', 'adapt_filename',
]].sort_values('params')

name,adapt_resolution,name,params,ds,filename,adapt_resolution,infer_samples_per_sec,adapt_ds,adapt_final_test,adapt_filename
Ti/16,224,Ti/16,5720000.0,i21k,Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0,224,3097.42,cifar100,0.8801,Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--cifar100-steps_10k-lr_0.003-res_224
Ti/16,384,Ti/16,5790000.0,i21k,Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0,384,609.58,cifar100,0.8746,Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--cifar100-steps_10k-lr_0.003-res_384
R+Ti/16,224,R+Ti/16,6340000.0,i21k,R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0,224,9371.0,cifar100,0.8595,R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--cifar100-steps_10k-lr_0.01-res_224
R+Ti/16,384,R+Ti/16,6360000.0,i21k,R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0,384,2425.77,cifar100,0.8578,R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--cifar100-steps_10k-lr_0.01-res_384
S/16,224,S/16,22050000.0,i21k,S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0,224,1508.35,cifar100,0.9209,S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--cifar100-steps_10k-lr_0.003-res_224
S/16,384,S/16,22200000.0,i21k,S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0,384,300.12,cifar100,0.92,S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--cifar100-steps_2k-lr_0.03-res_384
S/32,224,S/32,22880000.0,i21k,S_32-i21k-300ep-lr_0.001-aug_none-wd_0.1-do_0.0-sd_0.0,224,8342.46,cifar100,0.9055,S_32-i21k-300ep-lr_0.001-aug_none-wd_0.1-do_0.0-sd_0.0--cifar100-steps_10k-lr_0.003-res_224
S/32,384,S/32,22920000.0,i21k,S_32-i21k-300ep-lr_0.001-aug_none-wd_0.1-do_0.0-sd_0.0,384,2153.94,cifar100,0.9046,S_32-i21k-300ep-lr_0.001-aug_none-wd_0.1-do_0.0-sd_0.0--cifar100-steps_10k-lr_0.003-res_384
R26+S/32,224,R26+S/32,36430000.0,i21k,R26_S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0,224,1814.25,cifar100,0.9244,R26_S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--cifar100-steps_10k-lr_0.003-res_224
R26+S/32,384,R26+S/32,36470000.0,i21k,R26_S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0,384,560.4,cifar100,0.9242,R26_S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--cifar100-steps_10k-lr_0.003-res_384


### Select models to experiment with

We can see that the highest inference throughput (`infer_samples_per_sec`) belongs to `R+Ti/16` and `S/32`, both at a fine-tuning resolution of 224.

They are also the fastest models to train.

We therefore restrict our ablation study to just these two models.
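The choice can be reproduced from the table above. The sketch below copies the resolution-224 rows shown there into a small DataFrame (the variable names are ours) and ranks them by inference throughput.

```python
import pandas as pd

# Resolution-224 rows copied from the CIFAR 100 table above.
res224 = pd.DataFrame({
    "name": ["Ti/16", "R+Ti/16", "S/16", "S/32", "R26+S/32"],
    "infer_samples_per_sec": [3097.42, 9371.0, 1508.35, 8342.46, 1814.25],
    "adapt_final_test": [0.8801, 0.8595, 0.9209, 0.9055, 0.9244],
})

# Two models with the highest inference throughput, fastest first.
fastest = res224.nlargest(2, "infer_samples_per_sec")
print(fastest)
```

The table only reports inference throughput, so ranking by it says nothing directly about training time; it also shows the accuracy traded away, since both selected models score below `S/16` and `R26+S/32` on `adapt_final_test`.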

### Explore the parameters in the checkpoint

The snippet below lets us explore the checkpoint identified by a given filename.

In [3]:
import tensorflow as tf
from vit_jax import checkpoint

filename = "R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--cifar100-steps_10k-lr_0.01-res_224"

path = f'gs://vit_models/augreg/{filename}.npz'

# Non-default checkpoints need to be loaded from local files,
# so copy the checkpoint once and load the local copy.
local_path = f'{filename}.npz'
if not tf.io.gfile.exists(local_path):
  tf.io.gfile.copy(path, local_path)
params = checkpoint.load(local_path)

In [4]:
params.keys()

dict_keys(['Transformer', 'cls', 'conv_root', 'embedding', 'gn_root', 'head'])

In [5]:
params['Transformer'].keys()

dict_keys(['encoder_norm', 'encoderblock_0', 'encoderblock_1', 'encoderblock_2', 'encoderblock_3', 'encoderblock_4', 'encoderblock_5', 'encoderblock_6', 'encoderblock_7', 'encoderblock_8', 'encoderblock_9', 'encoderblock_10', 'encoderblock_11', 'posembed_input'])

In [6]:
params['Transformer']['encoderblock_0']['MultiHeadDotProductAttention_0'].keys()

dict_keys(['key', 'out', 'query', 'value'])

In [7]:
params['Transformer']['encoderblock_0']['MultiHeadDotProductAttention_0']['query']['bias'].shape

(3, 64)
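Rather than indexing one level at a time, the whole parameter tree can be walked recursively. The following is a minimal sketch, assuming `params` is a nested mapping whose leaves are arrays; the helper `iter_shapes` is hypothetical, and `toy_params` is a stand-in that reproduces only the one shape shown above, so the example runs without downloading a checkpoint.

```python
from collections.abc import Mapping

import numpy as np


def iter_shapes(tree, prefix=""):
    """Yield (path, shape) for every array leaf in a nested mapping."""
    for key, value in tree.items():
        path = f"{prefix}/{key}" if prefix else key
        if isinstance(value, Mapping):
            yield from iter_shapes(value, path)
        else:
            yield path, np.shape(value)


# Stand-in for the loaded checkpoint (assumption: same nesting as above).
toy_params = {
    "Transformer": {
        "encoderblock_0": {
            "MultiHeadDotProductAttention_0": {
                "query": {"bias": np.zeros((3, 64))},
            },
        },
    },
}

shapes = dict(iter_shapes(toy_params))
for path, shape in shapes.items():
    print(path, shape)

# Assumption: the bias is laid out as (number of heads, per-head dimension).
num_heads, head_dim = shapes[
    "Transformer/encoderblock_0/MultiHeadDotProductAttention_0/query/bias"
]
```

Under that reading, the `(3, 64)` bias would correspond to 3 attention heads of 64 dimensions each, i.e. a hidden size of 3 × 64 = 192. Calling `dict(iter_shapes(params))` on the real checkpoint should list every parameter in the same way.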