In [26]:
# 참고 https://github.com/FrancescoSaverioZuppichini/ViT
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms
import einops
from einops.layers.torch import Rearrange


In [27]:
x = torch.randn(8,3,224,224)
x.shape

torch.Size([8, 3, 224, 224])

Patch Embedding   
- 이미지를 Patch로 나누는 방법 2가지   
-- 1. einops의 rearrange   
-- 2. Covn2d layer로 patch크기와 같은 filter를 사용   
   
   Batch * C * H * W --> Batch * N * (P * P * C)   
   H * W --> N * ( P * P )   
- 실제의 VIT에서는 einops같은 Linear Embedding 보다 Conv2d Layer로 사용한 후 Flatten 한 것이 performance gain이 있습니다   
   
   -- google research 에서는 conv2d 후 jax.numpy로 reshape 

In [28]:
# 1. einops.rearrange 함수로 patch
# 8x3x(14*16)x(14*16) -> 8x(14*14)x(16*16*3) 으로 flatten
patch_size = 16

print(x.shape)
patches = einops.rearrange(x, 'b c (h s1) (w s2) -> b (h w) (s1 s2 c)', s1=patch_size, s2=patch_size)
print(f'patches : {patches.shape}')

torch.Size([8, 3, 224, 224])
patches : torch.Size([8, 196, 768])


In [29]:
# 2. convlayer로 patch 만들기
patch_size = 16
in_channels =3
emb_size = 768

projection = nn.Sequential(
    nn.Conv2d(in_channels, emb_size, kernel_size= patch_size,
    stride=patch_size),
    # einops.layers.torch.Rearrange 함수 사용
    Rearrange('b e (h) (w) -> b(h w) e')
)
projection(x).shape

torch.Size([8, 196, 768])

Class_Token과 Positional Encoding 코드

In [30]:
emb_size = 768
img_size = 224
patch_size = 16

# patch만들기
patch_x = projection(x)
print(f'Patch x shape : {patch_x.shape}')

#class token
cls_token = nn.Parameter(torch.randn(1,1,emb_size))
print(f'Class Token shape : {cls_token.shape}')

batch_size = 8
cls_token = einops.repeat(cls_token, '() n e -> b n e', b=batch_size)
print(f'Class Token after Repeat batch size : {cls_token.shape}')

# position encoding
# H position -> 224/16 =14  , W position -> 224/16 =14 , 14*14 만큼 position
# class와 patch를 concat해서 하나가 더 생긴다. 그러므로 포지션도 1개더 만들어줌
position = nn.Parameter(torch.randn((img_size//patch_size)**2+1, emb_size))
print(f'Position : {position.shape}')

# cls_token과 patch_x 를 concatenate
concat_x = torch.cat([cls_token,patch_x], dim=1)
print(f'concat x shape : {concat_x.shape}')

# posistion 을 더해준다.
concat_x += position


Patch x shape : torch.Size([8, 196, 768])
Class Token shape : torch.Size([1, 1, 768])
Class Token after Repeat batch size : torch.Size([8, 1, 768])
Position : torch.Size([197, 768])
concat x shape : torch.Size([8, 197, 768])


Class로 Patch embedding 구현   
-- Bool value of Tensor / tensor로 bool값으로 비교 하려 할때 나오는 error   
-- patchEmbedding class를 선언하고 x에 적용해야함

In [31]:
class PatchEmbedding(nn.Module):
    def __init__(self, in_channels=3, patch_size=16, emb_size= 768, img_size=224):
        super().__init__()

        assert img_size % patch_size ==0, 'Image dimensions must be divisible by the patch size.'

        self.patch_size = patch_size
        self.projection = nn.Sequential(
            nn.Conv2d(
                in_channels, 
                emb_size, 
                kernel_size=patch_size,
                stride=patch_size
            ),
            Rearrange('b e (h) (w) -> b (h w) e'),
        )
        
        self.cls_token =nn.Parameter(torch.randn(1,1,emb_size))
        self.position = nn.Parameter(torch.randn((img_size//patch_size)**2+1, emb_size))
    
    def forward(self, x):
        b, _, _, _ = x.shape
        x = self.projection(x)
        cls_tokens = einops.repeat(self.cls_token, '() n e -> b n e',b=b)
        x = torch.cat([cls_tokens, x], dim=1)
        x += self.position
        
        return x 

In [32]:
Patch_Embedding = PatchEmbedding()
patch = Patch_Embedding(x)
print(f'patch shape : {patch.shape}')

patch shape : torch.Size([8, 197, 768])


Multi-Head Attention   
- Query, Key, Value 만들기

In [33]:
emb_size = 768
num_heads = 8

query = nn.Linear(emb_size, emb_size)
key = nn.Linear(emb_size,emb_size)
value = nn.Linear(emb_size,emb_size)
print(f'{query}\n{key}\n{value}')

Linear(in_features=768, out_features=768, bias=True)
Linear(in_features=768, out_features=768, bias=True)
Linear(in_features=768, out_features=768, bias=True)


In [34]:
print(f'query(x) shape : {query(patch).shape}')
query = einops.rearrange(query(patch), 'b n (h d) -> b h n d', h=num_heads)
key = einops.rearrange(key(patch), 'b n (h d) -> b h n d', h=num_heads)
value = einops.rearrange(value(patch), 'b n (h d) -> b h n d', h=num_heads)
print(f'query : {query.shape} \nkey : {key.shape}\nvalue : {value.shape}')

query(x) shape : torch.Size([8, 197, 768])
query : torch.Size([8, 8, 197, 96]) 
key : torch.Size([8, 8, 197, 96])
value : torch.Size([8, 8, 197, 96])


현재의 Query 에 대해 모든 Key값을 한번 씩 곱한다   
Query * Key^T 에 Softmax한 확률   
-> Softmax * value    

- matmul 와 einsum 2가지 방법이 있다.

In [35]:
# Query * Key
print(f'Query shape : {query.shape}')
print(f'Key shape : {key.shape}')
score = torch.matmul(query,key.transpose(-1,-2))
score2 = torch.einsum('bhqd, bhkd -> bhqk', query,key)
print(f'score shape : {score2.shape}')
print(f'score == score2 ? {(score==score2).all()}\n')

# Attention Score / emb_size 에 루트한 값을 나눈다
scaling = emb_size ** (1/2)
print(f'scaling : {scaling}')
score /= scaling
attention = torch.nn.functional.softmax(score, dim=-1)
print(f'attention : {attention.shape}')

# Attention score * value
out = torch.matmul(attention, value)
out2 = torch.einsum('bhal, bhlv -> bhav', attention, value)
print(f'Attention * value : {out.shape}')
print(f'out == out2 ?: {(out==out2).all()}')

# Rearrange to emb_size
out = einops.rearrange(out, 'b h n d -> b n (h d)')
print(f'output : {out.shape}')
print(f'patch와 동일한 크기가 나옴')

Query shape : torch.Size([8, 8, 197, 96])
Key shape : torch.Size([8, 8, 197, 96])
score shape : torch.Size([8, 8, 197, 197])
score == score2 ? True

scaling : 27.712812921102035
attention : torch.Size([8, 8, 197, 197])
Attention * value : torch.Size([8, 8, 197, 96])
out == out2 ?: True
output : torch.Size([8, 197, 768])
patch와 동일한 크기가 나옴


QKV 당 1개의 Linear Layer를 적용한 것을 텐서 연산을 한번에 하기위해   
emb_size *3으로 설정한 후 각각 나누어 준다.   
Attention 시 무시할 정보가 있을 경우 masking으로 하기 위해 

In [36]:
class MultiHeadAttention(nn.Module):
    def __init__(self, emb_size=768, num_heads=8, dropout = 0.):
        super().__init__()
        self.emb_size = emb_size
        self.num_heads = num_heads

        self.qkv= nn.Linear(emb_size, emb_size *3)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)
        self.scaling = (emb_size//num_heads)**(-1/2)

    def forward(self, x, mask=None):
        qkv = einops.rearrange(self.qkv(x),'b n ( h d qkv) -> (qkv) b h n d', h=self.num_heads, qkv=3)

        query, key, value = qkv[0], qkv[1], qkv[2]

        score = torch.einsum('bhqd, bhkd -> bhqk', query,key)
        #print(f'score shape : {score.shape}')

        if mask is not None:
            fill_value = torch.finfo(torch.float32).min
            score.mask_fill(~mask, fill_value)
        
        score= score * self.scaling

        atten = torch.nn.functional.softmax(score,dim=-1)
        atten = self.att_drop(atten)
        #print(f'attention shape : {atten.shape}')

        out = torch.einsum('bhal, bhlv -> bhav',atten, value)
        out = einops.rearrange(out, 'b h n d -> b n (h d)')
        #print(f'out shape : {out.shape}')
        # 왜 마지막에 Linear 하는거지?
        out = self.projection(out)

        return out

In [37]:
Multihead = MultiHeadAttention()
output = Multihead(patch)
print(output.shape)

torch.Size([8, 197, 768])


Residual Block

In [38]:
class ResidualAdd(nn.Module):
    def __init__(self,fn):
        super().__init__()
        self.fn =fn
    
    def forward(self, x, **kwargs):
        res = x 
        x = self.fn(x, **kwargs)
        x += res
        return x

MLP Block


In [39]:
class FeedForwardBlock(nn.Sequential):
    def __init__(self, emb_size: int, expansion: int=4, drop_p : float =0.):
        super().__init__(
            nn.Linear(emb_size, expansion * emb_size),
            nn.GELU(),
            nn.Dropout(drop_p),
            nn.Linear(expansion * emb_size, emb_size),
        )

Transformer Encoder   
patch embding -> MultiHead attention -> MLP

In [40]:
class TransformerEncoderBlock(nn.Sequential):
    def __init__(
        self,
        emb_size = 768,
        drop_p = 0.,
        forward_expansion = 4,
        forward_drop_p = 0.,
        **kwargs):
        super().__init__(
            ResidualAdd(nn.Sequential(
                nn.LayerNorm(emb_size),
                MultiHeadAttention(emb_size,**kwargs),
                nn.Dropout(drop_p))
            ),
            ResidualAdd(nn.Sequential(
                nn.LayerNorm(emb_size),
                FeedForwardBlock(
                    emb_size, expansion=forward_expansion, drop_p=forward_drop_p
                ),
                nn.Dropout(drop_p))
            )
            )


In [41]:
class TransformerEncoder(nn.Sequential):
    def __init__(self, depth: int = 12, **kwargs):
        super().__init__(*[TransformerEncoderBlock(**kwargs) for _ in range(depth)])

In [42]:
class ClassificationHead(nn.Sequential):
    def __init__(self, emb_size: int = 768, n_classes:int =1000):
        super().__init__(
            einops.layers.torch.Reduce('b n e -> b e', reduction='mean'),
            nn.LayerNorm(emb_size),
            nn.Linear(emb_size, n_classes)
        )

In [43]:
class ViT(nn.Sequential):
    def __init__(
        self,
        in_channels: int = 3,
        patch_size : int = 16,
        emb_size : int = 768,
        img_size : int =224,
        depth: int = 12,
        n_classes: int = 1000,
        **kwargs):
        super().__init__(
            PatchEmbedding(in_channels, patch_size, emb_size, img_size),
            TransformerEncoder(depth, emb_size = emb_size, **kwargs),
            ClassificationHead(emb_size, n_classes)
        )

In [44]:
from torchsummary import summary
summary(ViT(), (3,224,224), device='cpu')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 768, 14, 14]         590,592
         Rearrange-2             [-1, 196, 768]               0
    PatchEmbedding-3             [-1, 197, 768]               0
         LayerNorm-4             [-1, 197, 768]           1,536
            Linear-5            [-1, 197, 2304]       1,771,776
           Dropout-6          [-1, 8, 197, 197]               0
            Linear-7             [-1, 197, 768]         590,592
MultiHeadAttention-8             [-1, 197, 768]               0
           Dropout-9             [-1, 197, 768]               0
      ResidualAdd-10             [-1, 197, 768]               0
        LayerNorm-11             [-1, 197, 768]           1,536
           Linear-12            [-1, 197, 3072]       2,362,368
             GELU-13            [-1, 197, 3072]               0
          Dropout-14            [-1, 19