In [179]:
from transformers import ViTImageProcessor, ViTModel
from PIL import Image
import requests
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass

In [219]:
class ViTPatchEmbeddings(nn.Module):
    """Cut an image into non-overlapping square patches and linearly embed each one.

    A Conv2d whose kernel size and stride both equal ``config.patch_size``
    applies the same linear projection to every patch independently.
    """

    def __init__(self, config):
        super().__init__()
        kernel = (config.patch_size, config.patch_size)
        self.projection = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=config.hidden_size,
            kernel_size=kernel,
            stride=kernel,
        )

    def forward(self, x):
        """Map a batched pixel tensor ``(B, C, H, W)`` to patch tokens ``(B, num_patches, hidden_size)``."""
        feature_map = self.projection(x)           # (B, hidden_size, H/p, W/p)
        tokens = feature_map.flatten(start_dim=2)  # (B, hidden_size, num_patches), row-major patch order
        return tokens.transpose(1, 2)

In [237]:
class ViTEmbeddings(nn.Module):
    """Build the encoder input sequence: [CLS] token + patch tokens + learned positions.

    The number of position embeddings is derived from ``config.image_size`` and
    ``config.patch_size``: one per patch plus one for the prepended [CLS] token.
    For 224px images with 16px patches this is 197. Deriving it means the module
    also works for other image/patch sizes instead of only 224/16.
    """

    def __init__(self, config):
        super().__init__()
        self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
        self.patch_embeddings = ViTPatchEmbeddings(config)
        num_patches = (config.image_size // config.patch_size) ** 2
        # +1 position for the [CLS] token concatenated in front of the patches.
        self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size))
        # Honor the configured rate; fall back to the previous hard-coded 0.0.
        self.dropout = nn.Dropout(getattr(config, "hidden_dropout_prob", 0.0), inplace=False)

    def forward(self, x):
        """Embed a pixel batch.

        Args:
            x: pixel tensor forwarded to ``ViTPatchEmbeddings``; its spatial size
                must match ``config.image_size`` so the position table lines up.

        Returns:
            Tensor of shape ``(B, num_patches + 1, hidden_size)``, with the [CLS]
            token at sequence position 0.
        """
        bs = x.shape[0]
        patches = self.patch_embeddings(x)
        cls_tokens = self.cls_token.expand(bs, -1, -1)
        tokens = torch.cat((cls_tokens, patches), dim=1)
        embeds = tokens + self.position_embeddings
        return self.dropout(embeds)

In [238]:
class ViTSdpaSelfAttention(nn.Module):
    """Multi-head self-attention computed with ``F.scaled_dot_product_attention``.

    The module honors two config values:
    - ``config.qkv_bias`` controls the bias of the Q/K/V projections.
    - ``config.attention_probs_dropout_prob`` sets dropout on the attention
      weights, which is applied only in training mode.

    Both fall back to the previous hard-coded behaviour (bias on, no dropout)
    when the config does not define them.
    """

    def __init__(self, config):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                f"hidden_size ({config.hidden_size}) must be divisible by "
                f"num_attention_heads ({config.num_attention_heads})"
            )
        self.n_head = config.num_attention_heads
        bias = getattr(config, "qkv_bias", True)
        self.query = nn.Linear(config.hidden_size, config.hidden_size, bias=bias)
        self.key = nn.Linear(config.hidden_size, config.hidden_size, bias=bias)
        self.value = nn.Linear(config.hidden_size, config.hidden_size, bias=bias)
        # Only the rate ``p`` is used (passed to SDPA); the module holds no parameters.
        self.dropout = nn.Dropout(getattr(config, "attention_probs_dropout_prob", 0.0), inplace=False)

    def forward(self, x):
        """Attend over the sequence.

        Args:
            x: ``(B, T, C)`` hidden states with ``C == config.hidden_size``.

        Returns:
            ``(B, T, C)`` tensor: the per-head attention outputs concatenated
            back together. The output projection is not applied here.
        """
        B, T, C = x.size()
        head_dim = C // self.n_head

        def to_heads(t):
            # (B, T, C) -> (B, n_head, T, head_dim)
            return t.view(B, T, self.n_head, head_dim).transpose(1, 2)

        q = to_heads(self.query(x))
        k = to_heads(self.key(x))
        v = to_heads(self.value(x))
        # Attention-probability dropout is active only while training.
        dropout_p = self.dropout.p if self.training else 0.0
        y = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
        # (B, n_head, T, head_dim) -> (B, T, C)
        return y.transpose(1, 2).contiguous().view(B, T, C)
        

In [239]:
class ViTSelfOutput(nn.Module):
    """Output projection applied to the multi-head attention result.

    Dropout (``config.hidden_dropout_prob``, default 0.0) is applied after the
    projection. The residual connection is not added here; the enclosing
    encoder layer adds it.
    """

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(getattr(config, "hidden_dropout_prob", 0.0), inplace=False)

    def forward(self, x):
        """Project ``(..., hidden_size)`` attention output back to ``hidden_size``, then apply dropout."""
        return self.dropout(self.dense(x))

In [240]:
class ViTSdpaAttention(nn.Module):
    """Attention sub-block: multi-head self-attention followed by its output projection."""

    def __init__(self, config):
        super().__init__()
        # The attribute names become state_dict key prefixes (``attention.*``,
        # ``output.*``), so they must stay as-is for pretrained weight copying.
        self.attention = ViTSdpaSelfAttention(config)
        self.output = ViTSelfOutput(config)

    def forward(self, x):
        attended = self.attention(x)
        return self.output(attended)

In [241]:
class ViTIntermediate(nn.Module):
    """First half of the MLP: expand ``hidden_size`` -> ``intermediate_size``, then GELU."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        self.intermediate_act_fn = nn.GELU()

    def forward(self, x):
        return self.intermediate_act_fn(self.dense(x))

In [242]:
class ViTOutput(nn.Module):
    """Second half of the MLP: project ``intermediate_size`` back to ``hidden_size``.

    Dropout (``config.hidden_dropout_prob``, default 0.0) is applied after the
    projection. The residual connection is added by the enclosing encoder layer.
    """

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.dropout = nn.Dropout(getattr(config, "hidden_dropout_prob", 0.0), inplace=False)

    def forward(self, x):
        """Map ``(..., intermediate_size)`` to ``(..., hidden_size)``, then apply dropout."""
        return self.dropout(self.dense(x))

In [259]:
class ViTLayer(nn.Module):
    """Pre-norm transformer encoder block.

    The block computes two residual steps:
    - ``h = Attn(LN_before(x)) + x``
    - ``MLP(LN_after(h)) + h``

    The MLP is ``ViTIntermediate`` (expand + GELU) followed by ``ViTOutput``
    (project back).
    """

    def __init__(self, config):
        super().__init__()
        hidden, eps = config.hidden_size, config.layer_norm_eps
        self.attention = ViTSdpaAttention(config)
        self.intermediate = ViTIntermediate(config)
        self.output = ViTOutput(config)
        self.layernorm_before = nn.LayerNorm(hidden, eps=eps)
        self.layernorm_after = nn.LayerNorm(hidden, eps=eps)

    def forward(self, x):
        # Attention sub-block: normalise, attend, add the residual.
        h = self.attention(self.layernorm_before(x)) + x
        # MLP sub-block: normalise, expand + GELU, project back, add the residual.
        mlp = self.output(self.intermediate(self.layernorm_after(h)))
        return mlp + h

In [260]:
class ViTEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` ViTLayer blocks applied in order."""

    def __init__(self, config):
        super().__init__()
        blocks = [ViTLayer(config) for _ in range(config.num_hidden_layers)]
        # ``layer`` (singular) becomes the state_dict prefix ``layer.<i>.*``,
        # which must match the checkpoint keys when weights are copied in.
        self.layer = nn.ModuleList(blocks)

    def forward(self, x):
        hidden = x
        for block in self.layer:
            hidden = block(hidden)
        return hidden

In [265]:
class ViTPooler(nn.Module):
    """Pool a sequence into one vector: take token 0, then Linear + tanh."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, x):
        # Position 0 holds the [CLS] token prepended by the embedding layer.
        first_token = x[:, 0]
        return self.activation(self.dense(first_token))

In [266]:
@dataclass
class BaseOutput:
    """Return value of ``MyViTModel.forward``.

    The field names match those read from the Hugging Face model output, so
    both models' results can be accessed the same way.
    """
    # Final hidden states after the last LayerNorm, e.g. (1, 197, 768) for ViT-B/16 at 224px.
    last_hidden_state: torch.Tensor
    # Pooled representation of token 0 (Linear + tanh), e.g. (1, 768).
    pooler_output: torch.Tensor
        
@dataclass
class Config:
    """Architecture hyper-parameters; defaults describe ``google/vit-base-patch16-224-in21k``.

    Every field carries a type annotation, so ``@dataclass`` turns it into a
    real field. ``Config()`` therefore uses the defaults below, and single
    values can be overridden, e.g. ``Config(num_hidden_layers=2)``.
    """
    _name_or_path: str = "google/vit-base-patch16-224-in21k"
    attention_probs_dropout_prob: float = 0.0
    hidden_act: str = "gelu"
    hidden_dropout_prob: float = 0.0
    hidden_size: int = 768
    image_size: int = 224
    initializer_range: float = 0.02
    intermediate_size: int = 3072
    layer_norm_eps: float = 1e-12
    model_type: str = "vit"
    num_attention_heads: int = 12
    num_channels: int = 3
    num_hidden_layers: int = 12
    patch_size: int = 16
    qkv_bias: bool = True

In [271]:
class MyViTModel(nn.Module):
    """From-scratch ViT backbone: embeddings -> encoder -> LayerNorm -> pooler.

    Submodule attribute names are chosen so that this model's state_dict keys
    match those of the Hugging Face ``ViTModel`` checkpoint, which lets
    :meth:`from_pretrained` load the weights key-for-key.
    """

    def __init__(self, config):
        super().__init__()
        self.embeddings = ViTEmbeddings(config)
        self.encoder = ViTEncoder(config)
        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.pooler = ViTPooler(config)

    def forward(self, pixel_values):
        """Encode a batch of images.

        Args:
            pixel_values: pixel tensor for the patch embedding; presumably
                ``(B, num_channels, image_size, image_size)`` as produced by
                ``ViTImageProcessor`` — confirm for other inputs.

        Returns:
            BaseOutput with two fields:
            - ``last_hidden_state``: encoder output after the final LayerNorm.
            - ``pooler_output``: the pooled token-0 representation.
        """
        hidden = self.encoder(self.embeddings(pixel_values))
        last_hidden_state = self.layernorm(hidden)
        pooler_output = self.pooler(last_hidden_state)
        return BaseOutput(last_hidden_state=last_hidden_state, pooler_output=pooler_output)

    @classmethod
    def from_pretrained(cls, model_name_or_path="google/vit-base-patch16-224-in21k", config=None):
        """Build a model and load the weights of a Hugging Face ViT checkpoint into it.

        Args:
            model_name_or_path: hub id or local path passed to
                ``ViTModel.from_pretrained``. Defaults to the checkpoint that
                the default ``Config`` describes.
            config: architecture config for the local model; defaults to
                ``Config()``. It must describe the same architecture as the
                checkpoint.

        Returns:
            A ``cls`` instance holding the pretrained weights. It is switched to
            eval mode, like the model returned by ``ViTModel.from_pretrained``.

        Raises:
            RuntimeError: from ``load_state_dict`` if a key is missing on either
                side or a tensor shape differs. No local parameter is silently
                left randomly initialised.
        """
        model = cls(config if config is not None else Config())
        model_hf = ViTModel.from_pretrained(model_name_or_path)
        # strict=True checks the key sets in both directions and all shapes,
        # then copies every tensor in place without tracking gradients.
        model.load_state_dict(model_hf.state_dict(), strict=True)
        return model.eval()

In [272]:
# From-scratch model with the HF checkpoint weights copied in. eval() disables
# dropout so inference is deterministic, like the HF reference model below.
m = MyViTModel.from_pretrained().eval()

In [273]:
# Reference implementation: the official Hugging Face ViT loaded from the same checkpoint.
model = ViTModel.from_pretrained(pretrained_model_name_or_path="google/vit-base-patch16-224-in21k")

In [274]:
# Sample COCO validation image used as input to both models.
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()  # fail loudly rather than handing an HTTP error body to PIL
image = Image.open(response.raw)
# Preprocess with the processor of the same checkpoint both models were loaded from.
processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k')
inputs = processor(images=image, return_tensors="pt")

In [275]:
# Reference forward pass. This is pure inference, so skip building the autograd graph.
with torch.no_grad():
    outputs = model(**inputs)
last1 = outputs.last_hidden_state
pool1 = outputs.pooler_output
last1.shape, pool1.shape

(torch.Size([1, 197, 768]), torch.Size([1, 768]))

In [276]:
# Forward pass through the from-scratch model. The results go into last2/pool2
# so the reference results in last1/pool1 survive for the comparison.
with torch.no_grad():
    outputs = m(**inputs)
last2 = outputs.last_hidden_state
pool2 = outputs.pooler_output
last2.shape, pool2.shape

(torch.Size([1, 197, 768]), torch.Size([1, 768]))

In [277]:
# From-scratch vs. HF reference. Use a small tolerance rather than bitwise
# equality: the two implementations may dispatch to different attention
# kernels, so exact equality is not guaranteed.
torch.allclose(last1, last2, atol=1e-5), torch.allclose(pool1, pool2, atol=1e-5)

(True, True)