In [2]:
from vision_transformers import models
from vision_transformers.utils.params import params
from torch.hub import load_state_dict_from_url

import torch

In [3]:
model = models.vit_b_16()
params(model)

86,567,656 total parameters.
86,567,656 training parameters.


In [4]:
print(model)

ViT(
  (patches): CreatePatches(
    (patch): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
  )
  (dropout): Dropout(p=0.0, inplace=False)
  (transformer): Transformer(
    (layers): ModuleList(
      (0): ModuleList(
        (0): Normalization(
          (norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (fn): Attention(
            (qkv): Linear(in_features=768, out_features=2304, bias=True)
            (out): Sequential(
              (0): Linear(in_features=768, out_features=768, bias=True)
              (1): Dropout(p=0.0, inplace=False)
            )
          )
        )
        (1): Normalization(
          (norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (fn): MLP(
            (mlp_net): Sequential(
              (0): Linear(in_features=768, out_features=3072, bias=True)
              (1): GELU(approximate=none)
              (2): Dropout(p=0.0, inplace=False)
              (3): Linear(in_features=3072, out_features=768, bias

In [4]:
model.transformer.layers[0]

ModuleList(
  (0): Normalization(
    (norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
    (fn): Attention(
      (qkv): Linear(in_features=768, out_features=2304, bias=True)
      (out): Sequential(
        (0): Linear(in_features=768, out_features=768, bias=True)
        (1): Dropout(p=0.0, inplace=False)
      )
    )
  )
  (1): Normalization(
    (norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
    (fn): MLP(
      (mlp_net): Sequential(
        (0): Linear(in_features=768, out_features=3072, bias=True)
        (1): GELU(approximate=none)
        (2): Dropout(p=0.0, inplace=False)
        (3): Linear(in_features=3072, out_features=768, bias=True)
        (4): Dropout(p=0.0, inplace=False)
      )
    )
  )
)

## Try Loading Torchvision Weights

In [6]:
weights = load_state_dict_from_url('https://download.pytorch.org/models/vit_b_16-c867db91.pth')

In [7]:
state_dict = model.state_dict()

In [8]:
print(list(state_dict.keys())[2])

patches.patch.weight


In [9]:
for i in range(len(state_dict.keys())):
    print(i)
    print(list(state_dict.keys())[i])

0
pos_embedding
1
cls_token
2
patches.patch.weight
3
patches.patch.bias
4
transformer.layers.0.0.norm.weight
5
transformer.layers.0.0.norm.bias
6
transformer.layers.0.0.fn.qkv.weight
7
transformer.layers.0.0.fn.qkv.bias
8
transformer.layers.0.0.fn.out.0.weight
9
transformer.layers.0.0.fn.out.0.bias
10
transformer.layers.0.1.norm.weight
11
transformer.layers.0.1.norm.bias
12
transformer.layers.0.1.fn.mlp_net.0.weight
13
transformer.layers.0.1.fn.mlp_net.0.bias
14
transformer.layers.0.1.fn.mlp_net.3.weight
15
transformer.layers.0.1.fn.mlp_net.3.bias
16
transformer.layers.1.0.norm.weight
17
transformer.layers.1.0.norm.bias
18
transformer.layers.1.0.fn.qkv.weight
19
transformer.layers.1.0.fn.qkv.bias
20
transformer.layers.1.0.fn.out.0.weight
21
transformer.layers.1.0.fn.out.0.bias
22
transformer.layers.1.1.norm.weight
23
transformer.layers.1.1.norm.bias
24
transformer.layers.1.1.fn.mlp_net.0.weight
25
transformer.layers.1.1.fn.mlp_net.0.bias
26
transformer.layers.1.1.fn.mlp_net.3.weight
27

In [10]:
def load_pretrained_state_dict(model, model_name='vit_b_16'):
    weights = load_state_dict_from_url(urls[model_name])
    # Model's current state dictionary.
    state_dict = model.state_dict()

    if model_name == 'vit_b_16' or model_name == 'vit_b_32':
        state_dict['cls_token'] = weights['class_token']
        state_dict['pos_embedding'] = weights['encoder.pos_embedding']
        state_dict['patches.patch.weight'] = weights['conv_proj.weight']
        state_dict['patches.patch.bias'] = weights['conv_proj.bias']
        
        for i in range(12):
            state_dict[f"transformer.layers.{i}.0.norm.weight"] = weights[f"encoder.layers.encoder_layer_{i}.ln_1.weight"]
            state_dict[f"transformer.layers.{i}.0.norm.bias"] = weights[f"encoder.layers.encoder_layer_{i}.ln_1.bias"]
            state_dict[f"transformer.layers.{i}.0.fn.qkv.weight"] = weights[f"encoder.layers.encoder_layer_{i}.self_attention.in_proj_weight"]
            state_dict[f"transformer.layers.{i}.0.fn.qkv.bias"] = weights[f"encoder.layers.encoder_layer_{i}.self_attention.in_proj_bias"]
            state_dict[f"transformer.layers.{i}.0.fn.out.0.weight"] = weights[f"encoder.layers.encoder_layer_{i}.self_attention.out_proj.weight"]
            state_dict[f"transformer.layers.{i}.0.fn.out.0.bias"] = weights[f"encoder.layers.encoder_layer_{i}.self_attention.out_proj.bias"]
            state_dict[f"transformer.layers.{i}.1.norm.weight"] = weights[f"encoder.layers.encoder_layer_{i}.ln_2.weight"]
            state_dict[f"transformer.layers.{i}.1.norm.bias"] = weights[f"encoder.layers.encoder_layer_{i}.ln_2.bias"]
            state_dict[f"transformer.layers.{i}.1.fn.mlp_net.0.weight"] = weights[f"encoder.layers.encoder_layer_{i}.mlp.linear_1.weight"]
            state_dict[f"transformer.layers.{i}.1.fn.mlp_net.0.bias"] = weights[f"encoder.layers.encoder_layer_{i}.mlp.linear_1.bias"]
            state_dict[f"transformer.layers.{i}.1.fn.mlp_net.3.weight"] = weights[f"encoder.layers.encoder_layer_{i}.mlp.linear_2.weight"]
            state_dict[f"transformer.layers.{i}.1.fn.mlp_net.3.bias"] = weights[f"encoder.layers.encoder_layer_{i}.mlp.linear_2.bias"]
            
        state_dict['ln.weight'] = weights['encoder.ln.weight']
        state_dict['ln.bias'] = weights['encoder.ln.bias']
        state_dict['mlp_head.0.weight'] = weights['heads.head.weight']
        state_dict['mlp_head.0.bias'] = weights['heads.head.bias']
    model.load_state_dict(state_dict)
    return model


In [11]:
state_dict = load_pretrained_state_dict(state_dict, weights)
model.load_state_dict(state_dict)

NameError: name 'urls' is not defined

## Try Loading timm Weights 

In [1]:
weights = torch.hub.load_state_dict_from_url('https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth')

for i in range(len(weights.keys())):
    print(i)
    print(list(weights.keys())[i])

  from .autonotebook import tqdm as notebook_tqdm
Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth" to /home/sovit/.cache/torch/hub/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth
100%|██████████| 330M/330M [00:52<00:00, 6.60MB/s] 


0
cls_token
1
norm.bias
2
norm.weight
3
blocks.0.norm1.bias
4
blocks.0.norm1.weight
5
blocks.0.norm2.bias
6
blocks.0.norm2.weight
7
blocks.0.mlp.fc1.bias
8
blocks.0.mlp.fc1.weight
9
blocks.0.mlp.fc2.bias
10
blocks.0.mlp.fc2.weight
11
blocks.0.attn.proj.bias
12
blocks.0.attn.proj.weight
13
blocks.0.attn.qkv.bias
14
blocks.0.attn.qkv.weight
15
blocks.1.norm1.bias
16
blocks.1.norm1.weight
17
blocks.1.norm2.bias
18
blocks.1.norm2.weight
19
blocks.1.mlp.fc1.bias
20
blocks.1.mlp.fc1.weight
21
blocks.1.mlp.fc2.bias
22
blocks.1.mlp.fc2.weight
23
blocks.1.attn.proj.bias
24
blocks.1.attn.proj.weight
25
blocks.1.attn.qkv.bias
26
blocks.1.attn.qkv.weight
27
blocks.10.norm1.bias
28
blocks.10.norm1.weight
29
blocks.10.norm2.bias
30
blocks.10.norm2.weight
31
blocks.10.mlp.fc1.bias
32
blocks.10.mlp.fc1.weight
33
blocks.10.mlp.fc2.bias
34
blocks.10.mlp.fc2.weight
35
blocks.10.attn.proj.bias
36
blocks.10.attn.proj.weight
37
blocks.10.attn.qkv.bias
38
blocks.10.attn.qkv.weight
39
blocks.11.norm1.bias
40


In [15]:
def load_pretrained_state_dict(model, model_name='vit_b_16'):
    weights = load_state_dict_from_url('https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth')
    # Model's current state dictionary.
    state_dict = model.state_dict()

    if model_name == 'vit_b_16' or model_name == 'vit_b_32':
        state_dict['cls_token'] = weights['cls_token']
        state_dict['pos_embedding'] = weights['pos_embed']
        state_dict['patches.patch.weight'] = weights['patch_embed.proj.weight']
        state_dict['patches.patch.bias'] = weights['patch_embed.proj.bias']
        
        for i in range(12):
            state_dict[f"transformer.layers.{i}.0.norm.weight"] = weights[f"blocks.{i}.norm1.weight"]
            state_dict[f"transformer.layers.{i}.0.norm.bias"] = weights[f"blocks.{i}.norm1.bias"]
            state_dict[f"transformer.layers.{i}.0.fn.qkv.weight"] = weights[f"blocks.{i}.attn.qkv.weight"]
            state_dict[f"transformer.layers.{i}.0.fn.qkv.bias"] = weights[f"blocks.{i}.attn.qkv.bias"]
            state_dict[f"transformer.layers.{i}.0.fn.out.0.weight"] = weights[f"blocks.{i}.attn.proj.weight"]
            state_dict[f"transformer.layers.{i}.0.fn.out.0.bias"] = weights[f"blocks.{i}.attn.proj.bias"]
            state_dict[f"transformer.layers.{i}.1.norm.weight"] = weights[f"blocks.{i}.norm2.weight"]
            state_dict[f"transformer.layers.{i}.1.norm.bias"] = weights[f"blocks.{i}.norm2.bias"]
            state_dict[f"transformer.layers.{i}.1.fn.mlp_net.0.weight"] = weights[f"blocks.{i}.mlp.fc1.weight"]
            state_dict[f"transformer.layers.{i}.1.fn.mlp_net.0.bias"] = weights[f"blocks.{i}.mlp.fc1.bias"]
            state_dict[f"transformer.layers.{i}.1.fn.mlp_net.3.weight"] = weights[f"blocks.{i}.mlp.fc2.weight"]
            state_dict[f"transformer.layers.{i}.1.fn.mlp_net.3.bias"] = weights[f"blocks.{i}.mlp.fc2.bias"]
            
        state_dict['ln.weight'] = weights['norm.weight']
        state_dict['ln.bias'] = weights['norm.bias']
        # MAYBE no need to load head weights.
        state_dict['mlp_head.0.weight'] = weights['head.weight']
        state_dict['mlp_head.0.bias'] = weights['head.bias']
    model.load_state_dict(state_dict)
    return model

In [16]:
load_pretrained_state_dict(model)

ViT(
  (patches): CreatePatches(
    (patch): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
  )
  (dropout): Dropout(p=0.0, inplace=False)
  (transformer): Transformer(
    (layers): ModuleList(
      (0): ModuleList(
        (0): Normalization(
          (norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (fn): Attention(
            (qkv): Linear(in_features=768, out_features=2304, bias=True)
            (out): Sequential(
              (0): Linear(in_features=768, out_features=768, bias=True)
              (1): Dropout(p=0.0, inplace=False)
            )
          )
        )
        (1): Normalization(
          (norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (fn): MLP(
            (mlp_net): Sequential(
              (0): Linear(in_features=768, out_features=3072, bias=True)
              (1): GELU(approximate=none)
              (2): Dropout(p=0.0, inplace=False)
              (3): Linear(in_features=3072, out_features=768, bias