## Vision transformers

Outline:
- Fill in model architecture (maybe ask them to implement attention by hand? might be too hard)
- Load pretrained weights
- Load ImageNet
- Fine-tune on ImageNet (write training loop)

## Vision transformer architecture

First, go to vit_architecture.py and fill in some model definitions. You might find this illustration of the architecture of a vision transformer helpful (Figure 1 from the [ViT paper](https://arxiv.org/pdf/2010.11929.pdf)).

![](vit_architecture.png)

Once you complete all the TODOs, run the cell below to import the model.

In [1]:
from vit_architecture_solutions import VisionTransformer
# from vit_architecture import VisionTransformer

config = {
    "num_classes": 10,
    "patch_size": 16,
    "hidden_size": 768,
    "model_name": 'ViT-S_16',
}

transformer_config = {
    "mlp_dim": 3072,
    "num_heads": 12,
    "num_layers": 12,
    "attn_dropout": 0.0,
    "dropout": 0.0,
}

model = VisionTransformer(
    num_classes=config["num_classes"], 
    patch_size=config["patch_size"],
    hidden_size=config["hidden_size"],
    model_name=config["model_name"],
    transformer_config=transformer_config,
)

model

VisionTransformer(
    # attributes
    num_classes = 10
    patch_size = 16
    hidden_size = 768
    transformer_config = {'mlp_dim': 3072, 'num_heads': 12, 'num_layers': 12, 'attn_dropout': 0.0, 'dropout': 0.0}
    cls_head_bias_init = 0.0
    model_name = 'ViT-B_16'
)

## Load pre-trained weights

In [5]:
import os
import numpy as np
import jax.numpy as jnp
from flax.training import checkpoints
from vit_utils import recover_tree

ModuleNotFoundError: No module named 'tensorflow'

In [3]:
checkpoint_file = "S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz"
if not os.path.exists("checkpoints"):
    os.system("mkdir checkpoints")
os.chdir("checkpoints")
os.system(f"wget https://storage.googleapis.com/vit_models/augreg/{checkpoint_file}")
os.chdir("..")
with open(f"checkpoints/{checkpoint_file}", "rb") as f:
    ckpt_dict = np.load(f, allow_pickle=False)
    keys, values = zip(*list(ckpt_dict.items()))

--2022-11-30 01:24:08--  https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz
Resolving storage.googleapis.com (storage.googleapis.com)... 142.251.32.240, 142.251.32.144, 142.250.68.176, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|142.251.32.240|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 120365794 (115M) [application/octet-stream]
Saving to: ‘S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz’

     0K .......... .......... .......... .......... ..........  0% 10.6M 11s
    50K .......... .......... .......... .......... ..........  0% 9.81M 11s
   100K .......... .......... .......... .......... ..........  0% 10.8M 11s
   150K .......... .......... .......... .......... ..........  0% 11.5M 11s
   200K .......... .......... .......... .......... ..........  0% 11.9M 11s
   250K .......... .......... .......... .......... ..........  0% 11.6M 10s
   300K ..........

In [4]:
params = checkpoints.convert_pre_linen(recover_tree(keys, values))

('Transformer/encoder_norm/bias',
 'Transformer/encoder_norm/scale',
 'Transformer/encoderblock_0/LayerNorm_0/bias',
 'Transformer/encoderblock_0/LayerNorm_0/scale',
 'Transformer/encoderblock_0/LayerNorm_2/bias',
 'Transformer/encoderblock_0/LayerNorm_2/scale',
 'Transformer/encoderblock_0/MlpBlock_3/Dense_0/bias',
 'Transformer/encoderblock_0/MlpBlock_3/Dense_0/kernel',
 'Transformer/encoderblock_0/MlpBlock_3/Dense_1/bias',
 'Transformer/encoderblock_0/MlpBlock_3/Dense_1/kernel',
 'Transformer/encoderblock_0/MultiHeadDotProductAttention_1/key/bias',
 'Transformer/encoderblock_0/MultiHeadDotProductAttention_1/key/kernel',
 'Transformer/encoderblock_0/MultiHeadDotProductAttention_1/out/bias',
 'Transformer/encoderblock_0/MultiHeadDotProductAttention_1/out/kernel',
 'Transformer/encoderblock_0/MultiHeadDotProductAttention_1/query/bias',
 'Transformer/encoderblock_0/MultiHeadDotProductAttention_1/query/kernel',
 'Transformer/encoderblock_0/MultiHeadDotProductAttention_1/value/bias',
 'Tr

## Load ImageNet

## Fine-tune