In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torchsummary import summary

In [2]:
class patch_embedding(nn.Module) :
    def __init__(self, patch_size, img_size, embed_size) :
        super(patch_embedding, self).__init__()
        
        self.patch_embedding = nn.Conv2d(3, embed_size, 
                                         kernel_size=patch_size, 
                                         stride=patch_size)
        # cls token을 패치 앞에 하나 더 붙여줌
        # ex) 9개의 패치로 이미지를 나누면 앞에 cls token이 붙어 최종적으로 10개의 패치가 되는 셈
        self.cls_token = nn.Parameter(torch.rand(1,1,embed_size))
        
        # cls token 1개가 더 붙었기 때문에 총 patch 개수에 + 1을 해줌
        self.position = nn.Parameter(torch.rand((img_size//patch_size)**2 + 1, embed_size))
    
    def forward(self, x) :
        x = self.patch_embedding(x)
        x = x.flatten(2)
        x = x.transpose(2,1)

        ct = self.cls_token.repeat(x.shape[0], 1, 1)
        x = torch.cat([ct, x],dim=1)
        x += self.position
        return x

In [3]:
data = torch.rand(5,3,224,224)
pe = patch_embedding(16, 224, 768)
y = pe(data)
print(y.shape)

torch.Size([5, 197, 768])


In [4]:
class multi_head_attention(nn.Module) :
    def __init__(self, embed_size, num_head, dropout_rate=0.1) :
        super(multi_head_attention, self).__init__()
        
        self.q = nn.Linear(embed_size, embed_size)
        self.k = nn.Linear(embed_size, embed_size)
        self.v = nn.Linear(embed_size, embed_size)
        
        self.fc = nn.Linear(embed_size, embed_size)
        self.dropout = nn.Dropout(dropout_rate)
        
        self.num_head = num_head
        self.embed_size = embed_size
    
    # 내가 쓴 노션(study list/Trasformer/Q,K,V 벡터 얻기)에 보면 
    # QKV는 각각 d / num_head 차원을 가져야 한다고 나와있음
    # 여기서는 d 가 embed_size 임
    # 따라서 (batch, patch_num, embed_size)에서 embed_size를 embed_size/num_head로 변경해야함
    # (batch, patch_num, num_head, embed_size//num_head)로 표현가능
    # 행렬을 이용하여 병렬로 처리 할 것이기 때문에
    # (batch, num_head, patch_num, embed_size//num_head)로 표현 함
    # (patch_num, embed_size//num_head)가 num_head 개수 만큼 있어야 1개의 값을 완성 한다고 보면됨
    def qkv_reshape(self, value, num_head) :
        b, n, emb = value.size()
        dim = emb // num_head
        return value.view(b, num_head, n, dim)
        
    def forward(self, x) :
        q = self.qkv_reshape(self.q(x), self.num_head)
        k = self.qkv_reshape(self.k(x), self.num_head)
        v = self.qkv_reshape(self.v(x), self.num_head)
        
        qk = torch.matmul(q, k.transpose(3,2))
        # torch.sqrt를 쓰니까 tensor가 아니라고 에러남
        att = F.softmax(qk / (self.embed_size ** (1/2)), dim=-1)
        att = torch.matmul(att, v)
        
        b, h, n, d = att.size()
        x = att.view(b, n, h*d)
        x = self.fc(x)
        x = self.dropout(x)
        print("DONE - MHA")
        return x

In [5]:
data = torch.rand(5, 197, 768)
mha = multi_head_attention(768, 8, dropout_rate=0.1)
y = mha(data)
y.shape

DONE - MHA


torch.Size([5, 197, 768])

In [6]:
class skip_connection(nn.Module) :
    def __init__(self, fn) :
        super(skip_connection, self).__init__()
        self.fn = fn
    
    def forward(self, x):
        skip = x
        x = self.fn(x)
        x += skip
        return x

In [7]:
data = torch.rand(5, 197, 768)
sc = skip_connection(mha)
y = sc(data)
y.shape

DONE - MHA


torch.Size([5, 197, 768])

In [8]:
class MLP(nn.Module) :
    def __init__(self, embed_size, expansion, dropout_rate):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(embed_size, embed_size*expansion)
        self.fc2 = nn.Linear(embed_size*expansion, embed_size)
        self.gelu = nn.GELU()
        self.dropout = nn.Dropout(dropout_rate)
        
    def forward(self, x) :
        x = self.fc1(x)
        x = self.gelu(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x

In [9]:
data = torch.rand(5, 197, 768)
mlp = MLP(768, 3, 0.1)
y = mlp(data)
y.shape

torch.Size([5, 197, 768])

In [10]:
class EncoderBlock(nn.Module) :
    def __init__(self, 
                 embed_size, 
                 num_head, 
                 expansion, 
                 dropout_rate):
        super(EncoderBlock, self).__init__()
        
        self.skip_connection1 = skip_connection(
            nn.Sequential(
                nn.LayerNorm(embed_size),
                multi_head_attention(embed_size, num_head, dropout_rate=0.1)
            )
        )
        
        self.skip_connection2 = skip_connection(
            nn.Sequential(
                nn.LayerNorm(embed_size),
                MLP(embed_size, expansion, dropout_rate=0.1)
            )
        )
    
    def forward(self, x) :
        x = self.skip_connection1(x)
        x = self.skip_connection2(x)
        return x

In [11]:
data = torch.rand(5, 196, 672)
eb = EncoderBlock(672, 7, 4, 0.1)
y = eb(data)
y.shape

DONE - MHA


torch.Size([5, 196, 672])

In [12]:
class Classifier_Head(nn.Module) :
    def __init__(self, embed_size, num_classes):
        super(Classifier_Head, self).__init__()
        
        self.avgpool1d = nn.AdaptiveAvgPool1d((1))
        self.fc = nn.Sequential(
            nn.LayerNorm(embed_size),
            nn.Linear(embed_size, num_classes)
        )
        
    def forward(self, x) :
        # 현재 x shape 은 (batch, num_patch, embed_size)임
        # 여기서 num_patch를 기준으로 avg pool을 해준 후 fc를 통과시킴
        # avgpool2d를 사용하면 (num_path, embed_size)를 기준으로 계산이 되서
        # 결과값이 (batch, ?, ?) 이런식으로 나옴. 여기서 "?"는 pool2d의 ouput 설정 값
        # 만약 nn.AdaptiveAvgPool2d((3,3))으로하면 (batch, 3, 3)으로 나옴
        # num_patch를 기준으로 avgpool1d를 사용하기 위해 transpose로  (batch, embed_size, num_patch)로 변경해줌
        # 그 후에 avgpool1d를 적용하여 (batch, embed_size, 1)로 만들어주고 squeeze로 1을 없애줌
        # 결과적으로 transpose -> avgpool1d -> squeeze(2) -> shape:(batch, embed_size)가 됨
        x = x.transpose(2,1)
        x = self.avgpool1d(x).squeeze(2)
        x = self.fc(x)
        return x

In [13]:
data = torch.rand(5, 196, 672)
ch = Classifier_Head(672, 10)
y = ch(data)
y.shape

torch.Size([5, 10])

In [14]:
data = torch.rand(3, 3, 4)
# print(data)
avg = nn.AdaptiveAvgPool1d((1))
y = avg(data)
print(y.shape)
# print(y)
y = y.squeeze(2)
print(y.shape)
# print(y)


torch.Size([3, 3, 1])
torch.Size([3, 3])


In [15]:
class VIT(nn.Module) :
    def __init__(self, 
                 patch_size=16, 
                 img_size=224, 
                 embed_size=768, 
                 num_head = 8,
                 expansion = 4,
                 dropout_rate = 0.1,
                 encoder_depth = 12,
                 num_classes = 10) :
        super(VIT, self).__init__()

        self.PatchEmbedding = patch_embedding(patch_size, img_size, embed_size)
        self.EncoderBlocks = self.make_layers(encoder_depth, embed_size, num_head, expansion, dropout_rate)
        self.ClassifierHead = Classifier_Head(embed_size, num_classes)
        
    def make_layers(self, encoder_depth, *args):
        layers = []
        for _ in range(0, encoder_depth) :
            layers.append(EncoderBlock(*args))
        return nn.Sequential(*layers)
    
    def forward(self, x) :
        x = self.PatchEmbedding(x)
        x = self.EncoderBlocks(x)
        x = self.ClassifierHead(x)
        
        return x

In [16]:
model = VIT()
summary(model, (3,224,224),device='cpu')

DONE - MHA
DONE - MHA
DONE - MHA
DONE - MHA
DONE - MHA
DONE - MHA
DONE - MHA
DONE - MHA
DONE - MHA
DONE - MHA
DONE - MHA
DONE - MHA
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 768, 14, 14]         590,592
   patch_embedding-2             [-1, 197, 768]               0
         LayerNorm-3             [-1, 197, 768]           1,536
            Linear-4             [-1, 197, 768]         590,592
            Linear-5             [-1, 197, 768]         590,592
            Linear-6             [-1, 197, 768]         590,592
            Linear-7             [-1, 197, 768]         590,592
           Dropout-8             [-1, 197, 768]               0
multi_head_attention-9             [-1, 197, 768]               0
  skip_connection-10             [-1, 197, 768]               0
        LayerNorm-11             [-1, 197, 768]           1,536
           Linear-12            [

In [17]:
def a(z, *args) :
    print("type : ", type(args))
    return b(*args)

def b(a,b,c) :
    return a+b+c
print(a(12, 5, 3, 4))

type :  <class 'tuple'>
12
