##### Copyright 2021 Google LLC.

In [30]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

<a href="https://colab.research.google.com/github/google-research/vision_transformer/blob/master/vit_jax_augreg.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Model versions

*Adapted from original `vit_jax_augreg.ipynb`.*

Use this notebook to find model versions and checkpoints in the original Google Cloud Storage bucket (`gs://vit_models/augreg/`).

In [1]:
# Some more imports used in this Colab.

import glob
import os
import random
import shutil
import time

from absl import logging
import pandas as pd
import seaborn as sns
import tensorflow as tf
import tensorflow_datasets as tfds
from matplotlib import pyplot as plt

# Show long cell values (e.g. checkpoint filenames) without truncation.
pd.set_option('display.max_colwidth', None)
logging.set_verbosity(logging.INFO)  # Shows logs during training.

2023-05-27 21:03:54.537508: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


### Explore checkpoints

This section shows how to use the `index.csv` table for model
selection.

See
[`vit_jax.checkpoint.get_augreg_df()`](https://github.com/google-research/vision_transformer/blob/ed1491238f5ff6099cca81087c575a215281ed14/vit_jax/checkpoint.py#L181-L228)
for a detailed description of the individual columns.

In [2]:
# TensorFlow's Cloud Storage support goes through libcurl (the auth warning
# logged below is a libcurl error), so point it at the system CA bundle.
# NOTE(review): this is a common Debian/Ubuntu location; presumably needed
# where TF cannot locate the CA certificates on its own — confirm.
# `setdefault` keeps any bundle the user already configured, and nothing is
# set on systems where the file does not exist.
CA_BUNDLE_PATH = '/etc/ssl/certs/ca-certificates.crt'
if os.path.exists(CA_BUNDLE_PATH):
  os.environ.setdefault('CURL_CA_BUNDLE', CA_BUNDLE_PATH)

# Load master table from Cloud.
INDEX_CSV_PATH = 'gs://vit_models/augreg/index.csv'
with tf.io.gfile.GFile(INDEX_CSV_PATH) as f:
  df = pd.read_csv(f)

2023-05-27 21:03:57.951291: W tensorflow/tsl/platform/cloud/google_auth_provider.cc:184] All attempts to get a Google authentication bearer token failed, returning an empty token. Retrieving token from files failed with "NOT_FOUND: Could not locate the credentials file.". Retrieving token from GCE failed with "FAILED_PRECONDITION: Error executing an HTTP request: libcurl code 6 meaning 'Couldn't resolve host name', error details: Could not resolve host: metadata.google.internal".


In [3]:
# This is a pretty large table with lots of columns:
num_rows = len(df)
print(f'loaded {num_rows:,} rows')
df.columns

loaded 51,509 rows


Index(['name', 'ds', 'epochs', 'lr', 'aug', 'wd', 'do', 'sd', 'best_val',
       'final_val', 'final_test', 'adapt_ds', 'adapt_lr', 'adapt_steps',
       'adapt_resolution', 'adapt_final_val', 'adapt_final_test', 'params',
       'infer_samples_per_sec', 'filename', 'adapt_filename'],
      dtype='object')

In [4]:
# Pre-training datasets (columns without the "adapt_" prefix are upstream).
df['ds'].unique()

array(['i1k', 'i21k', 'i21k_30'], dtype=object)

In [5]:
# Number of distinct checkpoints: every `.npz` file in the bucket (presumably
# both pre-trained and fine-tuned checkpoints).
npz_paths = tf.io.gfile.glob('gs://vit_models/augreg/*.npz')
len(npz_paths)

52268

In [6]:
# Any column prefixed with "adapt_" pertains to the fine-tuned checkpoints.
# Any column without that prefix pertains to the pre-trained checkpoints.
# Count distinct pre-trained vs. fine-tuned checkpoint filenames.
tuple(len(set(df[column])) for column in ('filename', 'adapt_filename'))

(759, 51509)

In [7]:
# Model architectures covered by the index.
df['name'].unique()

array(['Ti/16', 'S/32', 'B/16', 'L/16', 'R50+L/32', 'R26+S/32', 'S/16',
       'B/32', 'R+Ti/16', 'B/8'], dtype=object)

In [8]:
# Upstream AugReg parameters (section 3.3):
# `dropna` options are passed by keyword: positional arguments to
# `DataFrame.dropna` are deprecated in pandas and trigger a warning.
(
df.groupby(['ds', 'name', 'wd', 'do', 'sd', 'aug']).filename
  .count().unstack().unstack().unstack()
  .dropna(axis=1, how='all').fillna(0).astype(int)
  .iloc[:7]  # Just show beginning of a long table.
)

  .dropna(1, 'all').fillna(0).astype(int)


Unnamed: 0_level_0,Unnamed: 1_level_0,aug,light0,light0,light1,light1,medium1,medium1,medium2,medium2,none,none,strong1,strong1,strong2,strong2
Unnamed: 0_level_1,Unnamed: 1_level_1,sd,0.0,0.1,0.0,0.1,0.0,0.1,0.0,0.1,0.0,0.1,0.0,0.1,0.0,0.1
Unnamed: 0_level_2,Unnamed: 1_level_2,do,0.0,0.1,0.0,0.1,0.0,0.1,0.0,0.1,0.0,0.1,0.0,0.1,0.0,0.1
ds,name,wd,Unnamed: 3_level_3,Unnamed: 4_level_3,Unnamed: 5_level_3,Unnamed: 6_level_3,Unnamed: 7_level_3,Unnamed: 8_level_3,Unnamed: 9_level_3,Unnamed: 10_level_3,Unnamed: 11_level_3,Unnamed: 12_level_3,Unnamed: 13_level_3,Unnamed: 14_level_3,Unnamed: 15_level_3,Unnamed: 16_level_3
i1k,B/16,0.03,68,68,68,68,68,68,68,68,68,68,68,68,68,68
i1k,B/16,0.1,68,68,68,68,68,68,68,68,68,68,68,68,68,68
i1k,B/32,0.03,68,68,68,68,68,68,68,68,68,68,68,68,68,68
i1k,B/32,0.1,68,68,68,68,68,68,68,68,68,68,68,68,68,68
i1k,L/16,0.03,68,68,68,68,68,68,68,68,68,68,68,68,68,68
i1k,L/16,0.1,68,68,68,68,68,68,68,68,68,68,68,68,68,68
i1k,R+Ti/16,0.03,68,68,68,68,68,68,68,68,68,68,68,68,68,68


In [9]:
# Downstream parameters (table 4)
# (Imbalance in 224 vs. 384 is due to recently added B/8 checkpoints)
# `dropna` options are passed by keyword: positional arguments to
# `DataFrame.dropna` are deprecated in pandas and trigger a warning.
(
df.groupby(['adapt_resolution', 'adapt_ds', 'adapt_lr', 'adapt_steps']).filename
  .count().astype(str).unstack().unstack()
  .dropna(axis=1, how='all').fillna('')
)

  .dropna(1, 'all').fillna('')


Unnamed: 0_level_0,adapt_steps,500,500,500,500,2500,2500,2500,2500,10000,10000,10000,10000,20000,20000
Unnamed: 0_level_1,adapt_lr,0.001,0.003,0.010,0.030,0.001,0.003,0.010,0.030,0.001,0.003,0.010,0.030,0.010,0.030
adapt_resolution,adapt_ds,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2
224,cifar100,,,,,759.0,759.0,759.0,759.0,759.0,759.0,759.0,759.0,,
224,imagenet2012,,,,,,,,,,,,,759.0,759.0
224,kitti,759.0,759.0,759.0,759.0,759.0,759.0,759.0,759.0,,,,,,
224,oxford_iiit_pet,759.0,759.0,759.0,759.0,759.0,759.0,759.0,759.0,,,,,,
224,resisc45,,,,,759.0,759.0,759.0,759.0,759.0,759.0,759.0,759.0,,
384,cifar100,,,,,756.0,756.0,756.0,756.0,756.0,756.0,756.0,756.0,,
384,imagenet2012,,,,,,,,,,,,,756.0,755.0
384,kitti,756.0,756.0,756.0,756.0,756.0,756.0,756.0,756.0,,,,,,
384,oxford_iiit_pet,756.0,756.0,756.0,756.0,756.0,756.0,756.0,756.0,,,,,,
384,resisc45,,,,,756.0,756.0,756.0,756.0,756.0,756.0,756.0,756.0,,


In [28]:
# Let's first select the "best checkpoint" for every model. We show in the
# paper (section 4.5) that one can get a good performance by simply choosing the
# best model by final pre-train validation accuracy ("final-val" column).
# Pre-training with imagenet21k 300 epochs (ds=="i21k") gives the best
# performance in almost all cases (figure 6, table 5).
# One pre-trained filename per model name (the row with the highest final_val).
best_filenames = set(
    df.query('ds=="i21k"')
    .groupby('name')
    .apply(lambda group: group.sort_values('final_val').iloc[-1])
    .filename
)

# Select all finetunes from these models (vectorized membership test).
best_df = df.loc[df.filename.isin(best_filenames)]

# Note: the 9 original models account for 9 * 68 == 612 fine-tunes; the
# recently added B/8 contributes fewer (the saved run shows 10 models and
# 646 rows in total).
len(best_filenames), len(best_df)

(10, 646)

In [11]:
# Same schema as `df`: `best_df` is a row subset of it.
best_df.keys()

Index(['name', 'ds', 'epochs', 'lr', 'aug', 'wd', 'do', 'sd', 'best_val',
       'final_val', 'final_test', 'adapt_ds', 'adapt_lr', 'adapt_steps',
       'adapt_resolution', 'adapt_final_val', 'adapt_final_test', 'params',
       'infer_samples_per_sec', 'filename', 'adapt_filename'],
      dtype='object')

In [12]:
# Downstream (fine-tuning) datasets available for the selected models.
best_df['adapt_ds'].unique()

array(['imagenet2012', 'cifar100', 'resisc45', 'oxford_iiit_pet', 'kitti'],
      dtype=object)

## Use this code to get the model weights of the fine-tuned models

In [56]:
# For every model, pick its best fine-tune (by downstream validation accuracy)
# on one downstream dataset. The underlying pre-trained checkpoints are the
# ones selected into `best_df` (ds=="i21k"), i.e. the "i21k_300" column of
# table 3.
# Other fine-tuning datasets in `best_df.adapt_ds`: 'imagenet2012',
# 'cifar100', 'resisc45', 'kitti'.
# NOTE(review): the saved output below lists i1k checkpoints, which
# `best_df` (ds=="i21k") cannot produce — it is stale; re-run to refresh.
ADAPT_DS = 'oxford_iiit_pet'

print_df = best_df[best_df.adapt_ds == ADAPT_DS].groupby('name').apply(
    lambda group: group.sort_values('adapt_final_val').iloc[-1]
)[[
    # Columns from upstream
    'name', 'ds',
    # Columns from downstream
    'adapt_resolution', 'infer_samples_per_sec', 'adapt_ds',
    'adapt_final_test', 'adapt_filename',
]].rename(columns={
    'ds': 'pretrained',
    'name': 'model_version',
    'adapt_resolution': 'img_size',
    'adapt_ds': 'dataset',
    # NOTE: despite the name, this holds the downstream *test* accuracy
    # (`adapt_final_test`); the name is kept so consumers of the output work.
    'adapt_final_test': 'reported_val_acc',
    'adapt_filename': 'filename',
}).replace({
    'i21k': 'imagenet21k',
    'i1k': 'imagenet',
    'imagenet2012': 'imagenet',
    'oxford_iiit_pet': 'oxfordpets',
})
# Key of the form "<model_version>-<img_size>-<pretrained>-<dataset>"; unique
# because the groupby above keeps one row per model.
print_df['identifier'] = (
    print_df.model_version
    + '-' + print_df.img_size.astype(str)
    + '-' + print_df.pretrained.astype(str)
    + '-' + print_df.dataset.astype(str)
)
# Slowest models (fewest inference samples/sec) first.
print_df = print_df.sort_values('infer_samples_per_sec')
# To get a dict of dicts keyed by `identifier` instead of a table, use:
#   print_df.set_index('identifier').to_dict(orient='index')
print_df


Unnamed: 0_level_0,model_version,pretrained,img_size,infer_samples_per_sec,dataset,reported_val_acc,filename,identifier
name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
L/16,L/16,imagenet,384,49.87,oxfordpets,0.928571,L_16-i1k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.1-sd_0.1--oxford_iiit_pet-steps_2k-lr_0.01-res_384,L/16-384-imagenet-oxfordpets
B/16,B/16,imagenet,384,137.92,oxfordpets,0.943566,B_16-i1k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.1-sd_0.1--oxford_iiit_pet-steps_0k-lr_0.01-res_384,B/16-384-imagenet-oxfordpets
S/16,S/16,imagenet,384,300.12,oxfordpets,0.940294,S_16-i1k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--oxford_iiit_pet-steps_0k-lr_0.01-res_384,S/16-384-imagenet-oxfordpets
R50+L/32,R50+L/32,imagenet,384,326.73,oxfordpets,0.93675,R50_L_32-i1k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.1-sd_0.1--oxford_iiit_pet-steps_2k-lr_0.01-res_384,R50+L/32-384-imagenet-oxfordpets
R26+S/32,R26+S/32,imagenet,384,560.4,oxfordpets,0.943309,R26_S_32-i1k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.1-sd_0.1--oxford_iiit_pet-steps_0k-lr_0.01-res_384,R26+S/32-384-imagenet-oxfordpets
Ti/16,Ti/16,imagenet,384,609.58,oxfordpets,0.927501,Ti_16-i1k-300ep-lr_0.001-aug_light0-wd_0.1-do_0.0-sd_0.0--oxford_iiit_pet-steps_2k-lr_0.001-res_384,Ti/16-384-imagenet-oxfordpets
B/32,B/32,imagenet,384,954.94,oxfordpets,0.926663,B_32-i1k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.1-sd_0.1--oxford_iiit_pet-steps_0k-lr_0.03-res_384,B/32-384-imagenet-oxfordpets
S/32,S/32,imagenet,384,2153.94,oxfordpets,0.920959,S_32-i1k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--oxford_iiit_pet-steps_2k-lr_0.003-res_384,S/32-384-imagenet-oxfordpets
R+Ti/16,R+Ti/16,imagenet,384,2425.77,oxfordpets,0.914963,R_Ti_16-i1k-300ep-lr_0.001-aug_light0-wd_0.1-do_0.0-sd_0.0--oxford_iiit_pet-steps_0k-lr_0.003-res_384,R+Ti/16-384-imagenet-oxfordpets


## Use this code to get the pre-trained models

In [29]:
# Note that this dataframe contains the models from the "i21k_300" column of
# table 3. All fine-tunes of a model in `best_df` share the same pre-trained
# `filename`, so the downstream dataset below only serves to pick one row
# per model.
print_df = best_df.query('adapt_ds=="imagenet2012"').groupby('name').apply(
    lambda group: group.sort_values('adapt_final_val').iloc[-1]
)[[
    # Columns from upstream
    'name', 'ds', 'filename',
]].rename(columns={
    'name': 'model_version',
    'ds': 'pretrained',
}).replace({
    'i21k': 'imagenet21k',
    'i1k': 'imagenet',
    'imagenet2012': 'imagenet',
    'oxford_iiit_pet': 'oxfordpets',
})
# Image size recorded for (and embedded in the identifier of) the
# pre-trained checkpoints.
PRETRAIN_IMG_SIZE = 224
print_df['img_size'] = PRETRAIN_IMG_SIZE
# Key of the form "<model_version>-<img_size>-<pretrained>".
print_df['identifier'] = (
    print_df.model_version
    + '-' + print_df.img_size.astype(str)
    + '-' + print_df.pretrained.astype(str)
)
# Print print_df as a dictionary of dictionaries, keyed by `identifier`.
print_df.set_index('identifier').to_dict(orient='index')

{'B/16-224-imagenet21k': {'model_version': 'B/16',
  'pretrained': 'imagenet21k',
  'filename': 'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0',
  'img_size': 224},
 'B/32-224-imagenet21k': {'model_version': 'B/32',
  'pretrained': 'imagenet21k',
  'filename': 'B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0',
  'img_size': 224},
 'B/8-224-imagenet21k': {'model_version': 'B/8',
  'pretrained': 'imagenet21k',
  'filename': 'B_8-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0',
  'img_size': 224},
 'L/16-224-imagenet21k': {'model_version': 'L/16',
  'pretrained': 'imagenet21k',
  'filename': 'L_16-i21k-300ep-lr_0.001-aug_strong1-wd_0.1-do_0.0-sd_0.0',
  'img_size': 224},
 'R+Ti/16-224-imagenet21k': {'model_version': 'R+Ti/16',
  'pretrained': 'imagenet21k',
  'filename': 'R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0',
  'img_size': 224},
 'R26+S/32-224-imagenet21k': {'model_version': 'R26+S/32',
  'pretrained': 'imagenet21k',
  'filename': 'R26_S_3