## Environment & Imports

In [2]:
import torch
import math
from torch import nn

## 전체적인 구조
- 트랜스포머
    - 인코더
        - 멀티헤드 어텐션
        - Add & Norm
        - FFN (Feed Forward Network)
        - Add & Norm
    - 디코더
        - 마스크 멀티헤드 어텐션
        - Add & Norm
        - 멀티헤드 어텐션
        - Add & Norm
        - FFN (Feed Forward Network)
        - Add & Norm
    - 리니어 레이어


### MHA (Multi-Head Attention) 

In [3]:
class MHA(nn.Module):
    """Multi-head attention.

    Projects query/key/value with separate linear layers, splits the feature
    dimension into ``head_num`` heads, runs scaled dot-product attention per
    head, concatenates the heads again and applies an output projection.

    Args:
        dim: model (feature) dimension; must be divisible by ``head_num``.
        head_num: number of attention heads.

    Raises:
        ValueError: if ``dim`` is not divisible by ``head_num``.
    """

    def __init__(self, dim=512, head_num=8) -> None:
        super().__init__()
        if dim % head_num != 0:
            # Without this check ``transform`` can silently regroup features
            # into the wrong heads whenever the element count happens to fit.
            raise ValueError(
                f"dim ({dim}) must be divisible by head_num ({head_num})"
            )
        self.head_num = head_num
        self.dim = dim

        self.query_embed = nn.Linear(dim, dim)
        self.key_embed = nn.Linear(dim, dim)
        self.value_embed = nn.Linear(dim, dim)
        self.output_embed = nn.Linear(dim, dim)

    def scaled_dot_product_attention(self, query, key, value, mask: bool = False, scale: "float | None" = None) -> torch.Tensor:
        """Compute ``softmax(query @ key^T * scale) @ value``.

        Args:
            query: tensor of shape (..., L, d).
            key: tensor of shape (..., S, d).
            value: tensor of shape (..., S, d_v).
            mask: if True, apply a causal mask so query position i only
                attends to key positions j <= i.
            scale: multiplier applied to the logits; defaults to 1/sqrt(d).

        Returns:
            Tensor of shape (..., L, d_v).
        """
        scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
        # Multiply by 1/sqrt(d) (not divide) so the logits shrink with head size
        # and the softmax does not saturate.
        attn_weight = query @ key.transpose(-2, -1) * scale_factor
        if mask:
            L, S = query.size(-2), key.size(-2)
            # Build the mask on the scores' device so CUDA inputs work too.
            causal = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
            attn_weight = attn_weight.masked_fill(causal.logical_not(), float("-inf"))
        attn_weight = torch.softmax(attn_weight, dim=-1)
        return attn_weight @ value

    def transform(self, x):
        """Reshape (batch, seq, dim) into (batch, head_num, seq, dim // head_num)."""
        batch_size = x.size(0)
        # view first (keeps each token's features contiguous per head), then
        # transpose so the head axis comes before the sequence axis.
        return x.view(batch_size, -1, self.head_num, self.dim // self.head_num).transpose(1, 2)

    def forward(self, query, key, value, mask: bool = False, scale: "float | None" = None) -> torch.Tensor:
        """Run multi-head attention.

        Args:
            query: tensor of shape (batch, L, dim).
            key: tensor of shape (batch, S, dim).
            value: tensor of shape (batch, S, dim).
            mask: if True, use a causal (lower-triangular) attention mask.
            scale: optional logit multiplier; defaults to 1/sqrt(dim // head_num).

        Returns:
            Tensor of shape (batch, L, dim).
        """
        batch_size = query.size(0)

        query = self.transform(self.query_embed(query))
        key = self.transform(self.key_embed(key))
        value = self.transform(self.value_embed(value))

        output = self.scaled_dot_product_attention(query, key, value, mask, scale)
        # (batch, heads, L, head_dim) -> (batch, L, heads * head_dim); the
        # transpose makes the tensor non-contiguous, so copy before view.
        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.dim)
        return self.output_embed(output)

In [4]:
# Shape check for MHA.forward (scaled_dot_product_attention runs inside it):
# the output should keep the (batch, seq, dim) shape of the query.
q, k, v = (torch.rand((4, 10, 64)) for _ in range(3))

mha = MHA(dim=64, head_num=8)

mha(q, k, v).shape

torch.Size([4, 10, 64])

### FFN (Feed Forward Network)

In [5]:
class FFN(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Linear.

    Args:
        in_dim: size of the input and output features.
        hidden_dim: size of the inner (expanded) layer.
    """

    def __init__(self, in_dim: int = 512, hidden_dim: int = 2048) -> None:
        super().__init__()
        # Creation order (fc1, fc2, relu) is kept so parameter init and
        # state_dict keys stay the same.
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, in_dim)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply the two-layer MLP to the last dimension of ``x``."""
        return self.fc2(self.relu(self.fc1(x)))

In [6]:
# Shape check: FFN expands 64 -> 256 internally and projects back to 64.
a = torch.rand((4, 10, 64))
ffn = FFN(in_dim=64, hidden_dim=256)
ffn(a).shape

torch.Size([4, 10, 64])

### Add&Norm

In [7]:
class AddNorm(nn.Module):
    """Residual "Add & Norm": returns LayerNorm(x1 + x2).

    Args:
        dim: size of the last dimension normalized by the LayerNorm.
    """

    def __init__(self, dim) -> None:
        super().__init__()
        self.norm = nn.LayerNorm(dim)

    def forward(self, x1, x2):
        """Sum the two inputs element-wise, then layer-normalize the result."""
        residual_sum = torch.add(x1, x2)
        return self.norm(residual_sum)

### Encoder

In [8]:
class EncoderBlock(nn.Module):
    """One Transformer encoder layer in post-LN form.

    Self-attention, then a position-wise feed-forward network. Each sublayer
    is wrapped in a residual connection followed by its own LayerNorm.

    Args:
        dim: model dimension.
        hidden_dim: inner size of the feed-forward network.
        head_num: number of attention heads.
    """

    def __init__(self, dim: int = 512, hidden_dim: int = 256, head_num: int = 8) -> None:
        super().__init__()
        self.mha = MHA(dim=dim, head_num=head_num)
        self.ffn = FFN(in_dim=dim, hidden_dim=hidden_dim)
        # Each sublayer needs its own Add & Norm. Reusing a single one would tie
        # the LayerNorm parameters of the attention and FFN sublayers.
        # ``add_norm`` keeps its original name and serves the attention sublayer.
        self.add_norm = AddNorm(dim)
        self.ffn_add_norm = AddNorm(dim)

    def forward(self, x):
        """Encode ``x``; the output has the same shape as the input.

        Args:
            x: input tensor, presumably (batch, seq, dim) as MHA expects.
        """
        x = self.add_norm(x, self.mha(x, x, x))
        x = self.ffn_add_norm(x, self.ffn(x))
        return x

In [9]:
class TransformerEncoder(nn.Module):
    """A stack of ``num_block`` EncoderBlocks applied one after another.

    Args:
        num_block: number of stacked encoder layers.
        dim: model dimension.
        hidden_dim: inner size of each feed-forward network.
        head_num: number of attention heads per layer.
    """

    def __init__(self, num_block, dim: int = 512, hidden_dim: int = 256, head_num: int = 8) -> None:
        super().__init__()
        layers = (
            EncoderBlock(dim=dim, hidden_dim=hidden_dim, head_num=head_num)
            for _ in range(num_block)
        )
        self.encoder = nn.Sequential(*layers)

    def forward(self, x):
        """Feed ``x`` through every encoder layer and return the last output."""
        encoded = self.encoder(x)
        return encoded

In [10]:
# Shape check: a 6-layer encoder preserves (batch, seq, dim).
a = torch.rand((4, 10, 512))
enc = TransformerEncoder(num_block=6)
enc(a).shape

torch.Size([4, 10, 512])

### Decoder

In [11]:
class DecoderBlock(nn.Module):
    """One Transformer decoder layer in post-LN form.

    The layer runs three sublayers in order:
    1. causally-masked self-attention,
    2. cross-attention over the encoder output,
    3. position-wise feed-forward network.

    Each sublayer is followed by a residual connection and its own LayerNorm.

    Args:
        dim: model dimension.
        hidden_dim: inner size of the feed-forward network.
        head_num: number of attention heads.
    """

    def __init__(self, dim: int = 512, hidden_dim: int = 256, head_num: int = 8) -> None:
        super().__init__()
        self.mask_mha = MHA(dim=dim, head_num=head_num)
        self.mha = MHA(dim=dim, head_num=head_num)
        self.ffn = FFN(in_dim=dim, hidden_dim=hidden_dim)
        # One Add & Norm per sublayer so their LayerNorm parameters are not tied.
        # ``add_norm`` keeps its original name and serves the self-attention sublayer.
        self.add_norm = AddNorm(dim)
        self.cross_add_norm = AddNorm(dim)
        self.ffn_add_norm = AddNorm(dim)

    def forward(self, x, encoder_output):
        """Decode ``x`` while attending to ``encoder_output``.

        Args:
            x: decoder input, presumably (batch, tgt_len, dim).
            encoder_output: encoder states, presumably (batch, src_len, dim).

        Returns:
            Tensor with the same shape as ``x``.
        """
        x = self.add_norm(x, self.mask_mha(x, x, x, mask=True))
        # Cross-attention: queries come from the decoder; keys and values come
        # from the encoder. MHA's argument order is (query, key, value).
        x = self.cross_add_norm(x, self.mha(x, encoder_output, encoder_output))
        x = self.ffn_add_norm(x, self.ffn(x))
        return x

In [12]:
class TransformerDecoder(nn.Module):
    """A stack of ``num_block`` DecoderBlocks sharing one encoder output.

    Args:
        num_block: number of stacked decoder layers.
        dim: model dimension.
        hidden_dim: inner size of each feed-forward network.
        head_num: number of attention heads per layer.
    """

    def __init__(self, num_block, dim: int = 512, hidden_dim: int = 256, head_num: int = 8) -> None:
        super().__init__()
        self.decoder = nn.ModuleList(
            DecoderBlock(dim=dim, hidden_dim=hidden_dim, head_num=head_num)
            for _ in range(num_block)
        )

    def forward(self, x, encoder_output):
        """Run ``x`` through each decoder layer in turn; every layer sees ``encoder_output``."""
        hidden = x
        for layer in self.decoder:
            hidden = layer(hidden, encoder_output)
        return hidden

In [13]:
# Shape check: a 6-layer decoder, reusing the input as a stand-in encoder output.
a = torch.rand((4, 10, 512))
dec = TransformerDecoder(num_block=6)
dec(a, encoder_output=a).shape

torch.Size([4, 10, 512])

### Transformer

In [14]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer with a final linear projection and softmax.

    There is no token embedding and no positional encoding here. Inputs are
    expected to already be ``dim``-sized feature vectors.

    Args:
        output_dim: size of the final projection (e.g. number of classes).
        encoder_num: number of encoder layers.
        decoder_num: number of decoder layers.
        dim: model dimension.
        hidden_dim: inner size of each feed-forward network.
        head_num: number of attention heads.
    """

    def __init__(self, output_dim, encoder_num, decoder_num, dim: int = 512, hidden_dim: int = 256, head_num: int = 8) -> None:
        super().__init__()

        # NOTE(review): no positional encoding is added to encoder or decoder
        # inputs. The original "positional" comments suggest one was intended.
        self.encoder = TransformerEncoder(encoder_num, dim, hidden_dim, head_num)
        self.decoder = TransformerDecoder(decoder_num, dim, hidden_dim, head_num)

        self.linear = nn.Linear(dim, output_dim)
        self.softmax = nn.Softmax(-1)

    def forward(self, x, tgt=None):
        """Encode ``x`` and decode ``tgt`` against it.

        Args:
            x: source sequence, presumably (batch, src_len, dim).
            tgt: decoder input, presumably (batch, tgt_len, dim). Defaults to
                ``x``, which reproduces the previous single-input behavior.

        Returns:
            Softmax probabilities over the last dimension, shape
            (batch, tgt_len, output_dim).
        """
        if tgt is None:
            tgt = x
        encoder_output = self.encoder(x)
        decoder_output = self.decoder(tgt, encoder_output)

        output = self.linear(decoder_output)
        # NOTE(review): returns probabilities. nn.CrossEntropyLoss expects raw
        # logits, so use ``self.linear`` output directly when training with it.
        return self.softmax(output)
        

In [15]:
# End-to-end shape check: 6 encoder + 6 decoder layers projecting to 10 outputs.
test_input = torch.rand((4, 10, 512))
transformer = Transformer(output_dim=10, encoder_num=6, decoder_num=6)
transformer(test_input).shape

torch.Size([4, 10, 10])