In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from einops import rearrange, repeat

In [2]:
def attention(query, key, value, dropout):
    '''
    Scaled dot-product attention.

    query, key, value: (b, h, n, d)
    dropout: callable applied to the attention weights after the softmax
    returns: (b, h, n, d), each output row being a weighted sum of the rows
             of `value`
    '''
    # Scale by sqrt(d), where d is the size of the last axis of `query`.
    scale = math.sqrt(query.size(-1))
    logits = (query @ key.transpose(-2, -1)) / scale

    # Normalise over the key axis, giving weights of shape (b, h, n, n).
    weights = dropout(F.softmax(logits, dim=-1))
    return weights @ value

In [3]:
class MSA(nn.Module):
    '''
    Multi-head self-attention: projects the input into per-head query, key
    and value tensors, runs `attention` on them and mixes the concatenated
    heads with a final linear layer.
    '''

    def __init__(self, h, d_model, p_dropout):
        '''
        h: number of attention heads (d_model must be divisible by h)
        d_model: dimensionality of each input token vector
        p_dropout: dropout probability handed to `attention`
        '''
        super().__init__()
        assert d_model % h == 0

        self.h = h
        self.d_k = d_model // h  # width of a single head
        # Three same-shaped projections, consumed in order as query, key, value.
        self.token_linears = nn.ModuleList(
            [nn.Linear(d_model, d_model) for _ in range(3)]
        )
        self.fin_linear = nn.Linear(d_model, d_model)
        self.attn = None  # placeholder attribute; nothing in this class assigns it
        self.dropout = nn.Dropout(p=p_dropout)

    def _split_heads(self, projected, batch):
        '''Reshape (b, n, h * d_k) into (b, h, n, d_k).'''
        return projected.view(batch, -1, self.h, self.d_k).transpose(1, 2)

    def forward(self, z):
        '''
        z: tensor whose first axis is the batch and whose last axis is d_model
        returns: tensor reshaped to (b, n, h * d_k) and passed through
                 `fin_linear`
        '''
        batch = z.size(0)

        query, key, value = (
            self._split_heads(projection(z), batch)
            for projection in self.token_linears
        )

        heads = attention(query, key, value, self.dropout)

        # Bring the head axis back next to the feature axis and flatten both.
        merged = heads.transpose(1, 2).contiguous()
        merged = merged.view(batch, -1, self.h * self.d_k)

        return self.fin_linear(merged)


In [4]:
class FFN(nn.Module):
    '''
    Feed-forward block: Linear(d_model -> d_ff), GELU, dropout,
    Linear(d_ff -> d_model).
    '''

    def __init__(self, d_model, d_ff, p_dropout):
        '''
        d_model: size of the last axis of the input and of the output
        d_ff: size of the hidden layer between the two linear maps
        p_dropout: dropout probability applied to the hidden activations
        '''
        super().__init__()

        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(p_dropout)

    def forward(self, x):
        '''Apply the block to `x`; the output has the same shape as `x`.'''
        hidden = self.dropout(F.gelu(self.w_1(x)))
        return self.w_2(hidden)

In [5]:
class Transformer(nn.Module):
    '''
    Stack of `num_layer` layers; each layer is an MSA block followed by an
    FFN block, both wrapped in a residual connection.
    '''

    def __init__(self, num_layer, num_head, dim_model, p_dropout, d_ff):
        '''
        num_layer: number of (MSA, FFN) pairs
        num_head: number of heads passed to every MSA
        dim_model: token dimensionality passed to every MSA and FFN
        p_dropout: dropout probability passed to every MSA and FFN
        d_ff: hidden width passed to every FFN
        '''
        super().__init__()
        # All attention blocks are constructed before any feed-forward block.
        self.msa_list = nn.ModuleList(
            [MSA(num_head, dim_model, p_dropout) for _ in range(num_layer)]
        )
        self.ffn_list = nn.ModuleList(
            [FFN(dim_model, d_ff, p_dropout) for _ in range(num_layer)]
        )

    def forward(self, z):
        '''Run `z` through every layer in order and return the result.'''
        for attention_block, feed_forward in zip(self.msa_list, self.ffn_list):
            attended = attention_block(z)
            z = attended + z
            transformed = feed_forward(z)
            z = transformed + z
        return z

In [6]:
class ViT(nn.Module):
    '''
    Vision-Transformer style classifier.

    An image is cut into patch_size x patch_size patches, each patch is
    flattened and linearly embedded, a learned class token is prepended,
    learned position embeddings are added, the sequence goes through
    `Transformer`, and the class token (index 0) is mapped to `num_cls`
    logits.
    '''

    def __init__(self, img_size=128, patch_size=16, num_cls=10, dim_model=128, num_layer=12, num_head=8, p_dropout=0.1, d_ff=256):
        '''
        img_size: side length of the (square) images the position embeddings are sized for
        patch_size: side length of a square patch; must divide img_size
        num_cls: number of output logits
        dim_model: token dimensionality
        num_layer, num_head, p_dropout, d_ff: forwarded to `Transformer`
        '''
        super(ViT, self).__init__()

        # Validate before deriving anything from the ratio.
        assert img_size % patch_size == 0 , 'img size must be divisible by patch size'

        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patch = (img_size // patch_size) ** 2
        # The factor 3 hard-codes three input channels.
        self.patch_dim = patch_size * patch_size * 3
        # One position vector per patch plus one (index 0) for the class token.
        self.pos = nn.Parameter(torch.randn(1, self.num_patch + 1, dim_model))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim_model))
        self.emb = nn.Linear(self.patch_dim, dim_model)
        self.transformer = Transformer(num_layer, num_head, dim_model, p_dropout, d_ff)
        self.mlp_head = nn.Linear(dim_model, num_cls)

    def forward(self, x, is_fine=False):
        '''
        x: image batch of shape (b, c, H, W); c must be 3 to match `emb`
        is_fine: when False the stored position embeddings are added as-is,
                 so the number of patches must equal `num_patch`. When True
                 the patch position embeddings are bilinearly resized to the
                 patch grid of `x`, so other resolutions are accepted.
        returns: logits of shape (b, num_cls)
        '''
        # img to patch
        _, _, h, w = x.shape
        x = rearrange(x, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = self.patch_size, p2 = self.patch_size)
        x = self.emb(x)
        cls_token = repeat(self.cls_token, '1 n d -> b n d', b = x.shape[0])
        x = torch.cat((cls_token, x), dim=1)
        if not is_fine:
            x = x + self.pos
        else:
            assert h % self.patch_size == 0 and w % self.patch_size == 0 , 'img size must be divisible by patch size'
            inp_h = h // self.patch_size
            inp_w = w // self.patch_size

            # Lay the patch embeddings (everything but the class-token slot)
            # out on their original 2-D grid so they can be resized as an image.
            pos = rearrange(self.pos[:, 1:, :], 'b (h w) d -> b d h w', w=self.img_size // self.patch_size)
            # `size` is (height, width), matching the 'b d h w' layout; passing
            # (width, height) would transpose the grid for non-square inputs.
            # align_corners=False is the default for bilinear mode, stated explicitly.
            pos = F.interpolate(pos, size=(inp_h, inp_w), mode='bilinear', align_corners=False)
            pos = rearrange(pos, 'b d h w -> b (h w) d')
            # Put the untouched class-token position back in front.
            pos = torch.cat([self.pos[:, :1, :], pos], dim=1)
            x = x + pos

        x = self.transformer(x)
        # Classify from the class token only.
        y = self.mlp_head(x[:, 0, :])
        return y

In [8]:
# Instantiate the model with its default hyper-parameters and draw one random
# 3-channel 128x128 image, the default `img_size` of ViT.
vit = ViT()
x = torch.randn(1, 3, 128, 128)
print(x.shape)

torch.Size([1, 3, 128, 128])


In [9]:
# Forward pass at the native resolution, using the stored position embeddings
# unchanged (is_fine=False is the default).
y = vit(x, is_fine=False)
print(y.shape)

torch.Size([1, 10])


In [11]:
# A random 3-channel image at twice the default side length (256 vs 128).
high_resolution_x = torch.randn(1, 3, 256, 256)
print(high_resolution_x.shape)

torch.Size([1, 3, 256, 256])


In [12]:
# Feed the 256x256 batch: with is_fine=True the position embeddings are
# interpolated from the 8x8 training grid to the 16x16 grid of this input.
# (Passing the 128x128 `x` here would never exercise the interpolation at a
# new resolution.)
y = vit(high_resolution_x, is_fine=True)
print(y.shape)

torch.Size([1, 10])


