In [27]:
import torch
import torch.nn as nn
from PIL import Image
import einops

In [28]:
class PatchEmbed(nn.Module):
    """Split a square image into non-overlapping square patches and embed each one.

    The projection is a ``Conv2d`` whose kernel size and stride both equal
    ``patch_size``, so every output position of the convolution corresponds to
    exactly one patch of the input.

    Args:
        img_size: Height and width of the (square) input image, in pixels.
        patch_size: Height and width of each (square) patch, in pixels.
            Must divide ``img_size``.
        in_channels: Number of channels of the input image.
        emb_dim: Size of each patch embedding.

    Attributes:
        num_patch: Number of patches per image, ``(img_size // patch_size) ** 2``.
    """

    def __init__(self, img_size, patch_size, in_channels=3, emb_dim=768):
        super().__init__()
        assert img_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'

        self.img_size = img_size
        self.patch_size = patch_size
        self.emb_dim = emb_dim
        self.num_patch = (img_size // patch_size) ** 2

        self.proj = nn.Conv2d(
            in_channels=in_channels,
            out_channels=emb_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Embed a batch of images as a sequence of patch tokens.

        Args:
            x: Image tensor of shape ``[b, in_channels, img_size, img_size]``.

        Returns:
            Tensor of shape ``[b, num_patch, emb_dim]``; patches are ordered
            row by row over the patch grid.

        Raises:
            ValueError: If the spatial size of ``x`` is not ``img_size`` x ``img_size``.
                Without this check the number of tokens would silently disagree
                with ``num_patch``.
        """
        height, width = x.shape[-2:]
        if height != self.img_size or width != self.img_size:
            raise ValueError(
                f'Expected input of spatial size {self.img_size}x{self.img_size}, '
                f'got {height}x{width}.'
            )

        # [b, emb_dim, img_size // patch_size, img_size // patch_size]
        x = self.proj(x)
        # Flatten the patch grid row-major, then move the embedding axis last:
        # [b, emb_dim, num_patch] -> [b, num_patch, emb_dim]
        return x.flatten(2).transpose(1, 2)

In [29]:
# Sanity check: a dummy 224x224 RGB batch cut into 16x16 patches should give
# (224 // 16) ** 2 = 196 tokens, each of size 768 (PatchEmbed's default emb_dim).
test = torch.zeros(1, 3, 224, 224)
patch_emb = PatchEmbed(img_size=224, patch_size=16)
x = patch_emb(test)
assert x.shape == (1, patch_emb.num_patch, 768)
x.shape

torch.Size([1, 196, 768])

In [53]:
class MultiHeadAttetion(nn.Module):
    """Multi-head scaled dot-product self-attention over a sequence of tokens.

    The class name keeps its original spelling ("Attetion") so existing
    references keep working.

    Args:
        emb_dim: Size of each input/output token embedding. Must be divisible
            by ``n_heads``.
        n_heads: Number of attention heads; each head works on
            ``emb_dim // n_heads`` features.
        qkv_bias: Whether the query/key/value projections have a bias term.
        attn_drop: Dropout probability applied to the attention weights.
        proj_drop: Dropout probability applied after the output projection.

    Raises:
        ValueError: If ``emb_dim`` is not divisible by ``n_heads``.
    """

    def __init__(self, emb_dim=768, n_heads=12, qkv_bias=True, attn_drop=0., proj_drop=0.):
        super().__init__()
        if emb_dim % n_heads != 0:
            # Otherwise n_heads * head_dim != emb_dim and the per-head split in
            # forward would fail only at call time.
            raise ValueError(f'emb_dim ({emb_dim}) must be divisible by n_heads ({n_heads}).')
        self.n_heads = n_heads
        self.emb_dim = emb_dim
        self.head_dim = emb_dim // n_heads
        # Attention scores are multiplied by 1 / sqrt(head_dim).
        self.scale = self.head_dim ** (-0.5)

        self.q = nn.Linear(emb_dim, emb_dim, bias=qkv_bias)
        self.k = nn.Linear(emb_dim, emb_dim, bias=qkv_bias)
        self.v = nn.Linear(emb_dim, emb_dim, bias=qkv_bias)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(emb_dim, emb_dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Apply self-attention to ``x``.

        Args:
            x: Token tensor of shape ``[b, num_tokens, emb_dim]``
                (e.g. ``num_patch + 1`` tokens when a class token is prepended).

        Returns:
            Tensor of shape ``[b, num_tokens, emb_dim]``.

        Raises:
            ValueError: If the last dimension of ``x`` is not ``emb_dim``.
        """
        b, num_patch, dim = x.shape
        if dim != self.emb_dim:
            raise ValueError(f'Expected last dimension {self.emb_dim}, got {dim}.')

        def split_heads(t):
            # [b, num_tokens, emb_dim] -> [b, n_heads, num_tokens, head_dim]
            return t.reshape(b, num_patch, self.n_heads, self.head_dim).transpose(1, 2)

        query = split_heads(self.q(x))
        key = split_heads(self.k(x))
        value = split_heads(self.v(x))

        # Scaled scores query @ key^T: [b, n_heads, num_tokens, num_tokens]
        score = torch.einsum('bhqd,bhkd->bhqk', query, key) * self.scale

        # Softmax over the key axis: each query's weights lie in [0, 1] and sum to 1.
        atten = torch.nn.functional.softmax(score, dim=-1)
        atten = self.attn_drop(atten)

        # Weighted sum of values: [b, n_heads, num_tokens, head_dim]
        out = torch.einsum('bhqk,bhkd->bhqd', atten, value)
        # Merge heads back: [b, num_tokens, n_heads * head_dim] == [b, num_tokens, emb_dim]
        out = out.transpose(1, 2).reshape(b, num_patch, self.emb_dim)

        # Output projection over the concatenated heads (emb_dim -> emb_dim).
        x = self.proj(out)
        return self.proj_drop(x)


In [54]:
# Self-attention over the patch tokens produced by the PatchEmbed check must
# preserve the [batch, num_tokens, emb_dim] shape.
Mutliattention = MultiHeadAttetion()
result = Mutliattention(x)
assert result.shape == x.shape

In [55]:
class MLP(nn.Module):
    """Two-layer perceptron: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        in_features: Size of each input feature vector.
        hidden_features: Width of the hidden layer.
        num_classes: Size of each output vector (output width of ``fc2``).
        p: Dropout probability, applied after the activation and after ``fc2``.
    """

    def __init__(self, in_features, hidden_features, num_classes, p=0.):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        # GELU (Gaussian Error Linear Unit) activation.
        self.gelu = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, num_classes)
        # A single Dropout module, reused after both layers.
        self.drop = nn.Dropout(p)

    def forward(self, x):
        """Map ``[..., in_features]`` to ``[..., num_classes]``."""
        hidden = self.drop(self.gelu(self.fc1(x)))
        return self.drop(self.fc2(hidden))