In [27]:
# Reference: https://visionhong.tistory.com/25
import torch
import torch.nn as nn
from PIL import Image
import einops

In [28]:
class PatchEmbed(nn.Module):
    """Cut an image into non-overlapping square patches and embed each one.

    A single Conv2d whose kernel size and stride both equal ``patch_size``
    applies the same linear projection to every patch.

    Args:
        img_size: side length of the square input image; must be divisible
            by ``patch_size``.
        patch_size: side length of each square patch.
        in_channels: number of channels in the input image.
        emb_dim: size of each patch embedding.
    """

    def __init__(self, img_size, patch_size, in_channels=3, emb_dim=768):
        super().__init__()
        assert img_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'

        self.img_size = img_size
        self.patch_size = patch_size
        patches_per_side = img_size // patch_size
        self.num_patch = patches_per_side * patches_per_side

        self.proj = nn.Conv2d(in_channels, emb_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        """Embed patches: [b, in_channels, H, W] -> [b, (H//patch_size)*(W//patch_size), emb_dim]."""
        # Conv output: [b, emb_dim, H // patch_size, W // patch_size]
        grid = self.proj(x)
        # Flatten the patch grid row-major, then move the embedding axis last.
        return grid.flatten(2).transpose(1, 2)

In [29]:
# Smoke test: a 224x224 RGB image cut into 16x16 patches gives 14*14 = 196 tokens of size 768.
test = torch.zeros(1, 3, 224, 224)
patch_emb = PatchEmbed(img_size=224, patch_size=16)
x = patch_emb(test)
assert x.shape == (1, patch_emb.num_patch, 768)
x.shape

torch.Size([1, 196, 768])

In [53]:
class MultiHeadAttetion(nn.Module):
    """Multi-head scaled dot-product self-attention, as used in ViT encoder blocks.

    The class name keeps its original spelling because other cells refer to it.

    Args:
        emb_dim: size of each token embedding; must be divisible by ``n_heads``.
        n_heads: number of attention heads; each head works on
            ``emb_dim // n_heads`` features.
        qkv_bias: whether the query/key/value projections have a bias.
        attn_drop: dropout probability applied to the attention weights.
        proj_drop: dropout probability applied after the output projection.

    Raises:
        ValueError: if ``emb_dim`` is not divisible by ``n_heads``.
    """

    def __init__(self, emb_dim=768, n_heads=12, qkv_bias=True, attn_drop=0., proj_drop=0.):
        super().__init__()
        if emb_dim % n_heads != 0:
            # Heads split the embedding evenly; a remainder would make the head
            # reshape in forward() impossible and head_dim/scale inconsistent.
            raise ValueError(
                f'emb_dim ({emb_dim}) must be divisible by n_heads ({n_heads})'
            )
        self.n_heads = n_heads
        self.emb_dim = emb_dim
        self.head_dim = emb_dim // n_heads
        # Standard 1/sqrt(d_k) scaling of the query-key dot products.
        self.scale = self.head_dim ** (-0.5)

        self.q = nn.Linear(emb_dim, emb_dim, bias=qkv_bias)
        self.k = nn.Linear(emb_dim, emb_dim, bias=qkv_bias)
        self.v = nn.Linear(emb_dim, emb_dim, bias=qkv_bias)

        self.attn_drop = nn.Dropout(attn_drop)
        # Output projection mixes the concatenated heads back into emb_dim features.
        self.proj = nn.Linear(emb_dim, emb_dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def _split_heads(self, t):
        """Reshape [b, n_tokens, emb_dim] -> [b, n_heads, n_tokens, head_dim]."""
        b, n_tokens, _ = t.shape
        return t.reshape(b, n_tokens, self.n_heads, self.head_dim).transpose(1, 2)

    def forward(self, x):
        """Apply multi-head self-attention.

        Args:
            x: tensor of shape [batch, n_tokens, emb_dim]. In ViT, n_tokens is
                num_patch + 1 because of the prepended class token.

        Returns:
            Tensor of shape [batch, n_tokens, emb_dim].

        Raises:
            ValueError: if the last dimension of ``x`` differs from ``emb_dim``.
        """
        b, n_tokens, dim = x.shape
        if dim != self.emb_dim:
            raise ValueError(
                f'expected last input dimension {self.emb_dim}, got {dim}'
            )

        # Each: [batch, n_heads, n_tokens, head_dim]
        query = self._split_heads(self.q(x))
        key = self._split_heads(self.k(x))
        value = self._split_heads(self.v(x))

        # [batch, n_heads, n_tokens, n_tokens]: scaled query-key similarities.
        score = (query @ key.transpose(-2, -1)) * self.scale
        # Normalize over the key axis so each query's weights sum to 1.
        atten = self.attn_drop(score.softmax(dim=-1))

        # [batch, n_heads, n_tokens, head_dim]: attention-weighted values.
        out = atten @ value
        # Concatenate the heads: [batch, n_tokens, emb_dim]
        out = out.transpose(1, 2).reshape(b, n_tokens, self.emb_dim)

        return self.proj_drop(self.proj(out))


In [54]:
# Self-attention keeps the token layout: output shape must equal input shape.
Mutliattention = MultiHeadAttetion()
result = Mutliattention(x)
assert result.shape == x.shape

In [64]:
class MLP(nn.Module):
    """Position-wise feed-forward network of a Transformer encoder block.

    fc1 (expand) -> GELU -> dropout -> fc2 (project) -> dropout. Dropout follows
    both dense layers, matching the ViT paper (Dosovitskiy et al., App. B.1).

    Args:
        emb_size: input feature size.
        expand_ratio: hidden size is ``int(emb_size * expand_ratio)``; a float
            ratio (e.g. 4.0) is accepted.
        output_size: output feature size.
        p: dropout probability.
    """

    def __init__(self, emb_size: int = 768, expand_ratio: float = 4, output_size: int = 768, p: float = 0.):
        super().__init__()
        # int() so that float ratios still yield a valid Linear size.
        hidden_size = int(emb_size * expand_ratio)
        self.fc1 = nn.Linear(emb_size, hidden_size)
        # GELU (Gaussian Error Linear Unit); the original author notes it converges quickly.
        self.gelu = nn.GELU()
        self.fc2 = nn.Linear(hidden_size, output_size)
        # Dropout is stateless, so one module serves both positions.
        self.drop = nn.Dropout(p)

    def forward(self, x):
        """Map [..., emb_size] -> [..., output_size]."""
        x = self.drop(self.gelu(self.fc1(x)))
        return self.drop(self.fc2(x))

In [65]:
class ViTEncoderBlock(nn.Module):
    """Pre-norm Transformer encoder block used by ViT.

    x -> x + Attn(LayerNorm(x)) -> x + MLP(LayerNorm(x)); both sub-layers are
    wrapped in residual connections.

    Args:
        emb_size: token embedding size.
        n_heads: number of attention heads.
        mlp_ratio: MLP hidden size as a multiple of ``emb_size``.
        qkv_bias: whether the q/k/v projections have a bias.
        attn_p: dropout probability on the attention weights.
        p: dropout probability after the attention output projection and
            inside the MLP.
    """

    def __init__(self, emb_size, n_heads: int = 12, mlp_ratio: int = 4, qkv_bias=True, attn_p=0., p=0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(emb_size, eps=1e-6)
        self.attn = MultiHeadAttetion(emb_size, n_heads, qkv_bias, attn_drop=attn_p, proj_drop=p)
        self.norm2 = nn.LayerNorm(emb_size, eps=1e-6)
        # Forward `p` so the MLP's dropout honors the block's dropout rate.
        self.mlp = MLP(emb_size, expand_ratio=mlp_ratio, output_size=emb_size, p=p)

    def forward(self, x):
        """Apply the block; input and output have the same shape."""
        # Residual connections around each pre-normalized sub-layer.
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x

In [70]:
class ViT(nn.Module):
    """Vision Transformer image classifier.

    Pipeline: patch embedding -> prepend a learnable class token -> add a
    learnable position embedding -> ``depth`` encoder blocks -> LayerNorm ->
    linear head applied to the class token.

    Args:
        img_size: side length of the square input image.
        patch_size: side length of each square patch.
        in_channels: number of input image channels.
        n_classes: number of output classes.
        emb_size: token embedding size.
        depth: number of encoder blocks.
        n_heads: attention heads per block.
        mlp_ratio: MLP hidden size as a multiple of ``emb_size``.
        qkv_bias: whether the q/k/v projections have a bias.
        p: dropout probability after the position embedding and inside blocks.
        attn_p: dropout probability on the attention weights.
        apply_softmax: if True (default, original behavior) ``forward`` returns
            class probabilities; set False to get raw logits, which is what
            ``nn.CrossEntropyLoss`` expects.
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_channels=3,
        n_classes=1000,
        emb_size=768,
        depth=12,
        n_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        p=0.,
        attn_p=0.,
        apply_softmax=True):
        super().__init__()

        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_channels=in_channels, emb_dim=emb_size
        )

        # Learnable class token, prepended to the patch sequence.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, emb_size))
        # Learnable position embedding for the class token plus every patch.
        self.pos_emb = nn.Parameter(torch.zeros(1, 1 + self.patch_embed.num_patch, emb_size))
        self.pos_drop = nn.Dropout(p)

        self.blocks = nn.ModuleList(
            [
                ViTEncoderBlock(
                    emb_size=emb_size,
                    n_heads=n_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    attn_p=attn_p,
                    p=p
                )
                for _ in range(depth)
            ]
        )

        self.norm = nn.LayerNorm(emb_size, eps=1e-6)
        self.head = nn.Linear(emb_size, n_classes)
        self.apply_softmax = apply_softmax

    def forward(self, x):
        """Classify a batch of images.

        Args:
            x: tensor of shape [batch, in_channels, img_size, img_size]. Other
                spatial sizes change the patch count, and adding ``pos_emb``
                then fails.

        Returns:
            [batch, n_classes] class probabilities, or raw logits when
            ``apply_softmax`` is False.
        """
        n_samples = x.shape[0]
        # [batch, num_patch, emb_size]
        x = self.patch_embed(x)
        # One copy of the class token per sample: [batch, 1, emb_size]
        cls_token = self.cls_token.expand(n_samples, -1, -1)
        # [batch, num_patch + 1, emb_size]
        x = torch.cat((cls_token, x), dim=1)

        x = x + self.pos_emb
        x = self.pos_drop(x)

        for block in self.blocks:
            x = block(x)

        # [batch, num_patch + 1, emb_size]
        x = self.norm(x)

        # [batch, emb_size]: the class token (index 0) is taken to represent the whole image.
        cls_token_final = x[:, 0]

        # [batch, n_classes]
        logits = self.head(cls_token_final)
        if self.apply_softmax:
            return torch.nn.functional.softmax(logits, dim=1)
        return logits
        

In [71]:
# End-to-end check: one image in, one probability distribution over 1000 classes out.
model = ViT()
model.eval()  # disable dropout for inference
with torch.no_grad():
    result = model(test)
assert result.shape == (1, 1000)
assert torch.allclose(result.sum(dim=-1), torch.ones(1))
result.shape

torch.Size([1, 197, 768])
torch.Size([1, 768])
torch.Size([1, 1000])
